diff --git a/progress/SpecForge/.devcontainer/Dockerfile b/progress/SpecForge/.devcontainer/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..8ffb0d0328f12064b311869d19aa60df32cd7484 --- /dev/null +++ b/progress/SpecForge/.devcontainer/Dockerfile @@ -0,0 +1,32 @@ +FROM lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/progress/SpecForge/.devcontainer/devcontainer.json b/progress/SpecForge/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000000000000000000000000000000..b2dbad2a745763b273af79b742640461f18b7894 --- /dev/null +++ b/progress/SpecForge/.devcontainer/devcontainer.json @@ -0,0 +1,30 @@ +{ + "name": "sglang", + "build": { + "dockerfile": "Dockerfile" + }, + "remoteUser": "devuser", + "customizations": { + "vscode": { + "extensions": [ + // Python development + "ms-python.python", + "charliermarsh.ruff", + // Rust development + "rust-lang.rust-analyzer", + "tamasfe.even-better-toml" + ] + } + }, + "forwardPorts": [], + "runArgs": [ + "--gpus", + "all" + ], + // The two lines below ensures that your local changes in the sglang + // repo is automatically synced to the sglang pip package installed + // in the dev docker container. You can remove / comment out these + // two lines if you prefer to sync code changes manually. + "workspaceMount": "source=${localWorkspaceFolder},target=/sgl-workspace/specforge,type=bind", + "workspaceFolder": "/sgl-workspace/specforge" +} diff --git a/progress/SpecForge/.github/CODEOWNERS b/progress/SpecForge/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..e4dbc44f0f9b24da1ad6a96eff14abe45f184255 --- /dev/null +++ b/progress/SpecForge/.github/CODEOWNERS @@ -0,0 +1,11 @@ +.github @FrankLeeeee +/specforge/core @FrankLeeeee +/specforge/data @zyksir @sleepcoo @shuaills +/specforge/layers @FrankLeeeee @FlamingoPg @sleepcoo @shuaills +/specforge/modeling @FlamingoPg @sleepcoo @shuaills @FrankLeeeee +/tests @FrankLeeeee +/assets @FrankLeeeee @zhyncs +/examples @shuaills @sleepcoo @FlamingoPg +/configs @FrankLeeeee @FlamingoPg +/benchmarks @FrankLeeeee +/scripts @shuaills @sleepcoo @FlamingoPg diff --git a/progress/SpecForge/.github/pull_request_template.md b/progress/SpecForge/.github/pull_request_template.md new file mode 100644 index 0000000000000000000000000000000000000000..296468dfb8c84c38784759283db598959572a91f --- /dev/null +++ b/progress/SpecForge/.github/pull_request_template.md @@ -0,0 +1,30 @@ + + +## Motivation + + + +## Modifications + + + +## Related Issues + + + +## Accuracy Test + + + +## Benchmark & Profiling + + + +## Checklist + +- [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). +- [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). +- [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html). +- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR. +- [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR. diff --git a/progress/SpecForge/assets/logo.svg b/progress/SpecForge/assets/logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..7f619f50a0be61ade41e82599a40db2a45b3c376 --- /dev/null +++ b/progress/SpecForge/assets/logo.svg @@ -0,0 +1,938 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/progress/SpecForge/benchmarks/README.md b/progress/SpecForge/benchmarks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..678ce7a257b8bddb2a44373a4eba1ff77595813a --- /dev/null +++ b/progress/SpecForge/benchmarks/README.md @@ -0,0 +1,67 @@ +# Benchmarking for Speculative Decoding + +## Overview + +We provided a unified script to test the performance of the Speculative Decoding with EAGLE3 algorithm on multiple datasets. You can follow the steps below to run the benchmarks. + +## Run Benchmarks + +### Launch SGLang and Benchmarker Concurrently + +`bench_eagle3.py` can help you launch a SGLang server process and a Benchmarking process concurrently. In this way, you don't have to launch the SGLang server manually, this script will manually handle the SGLang launch under different speculative decoding configurations. Some important arguments are: +- `--model-path`: the path to the target model. +- `--speculative-draft-model-path`: the path to the draft model. +- `--port`: the port to launch the SGLang server. +- `--trust-remote-code`: trust the remote code. +- `--mem-fraction-static`: the memory fraction for the static memory. +- `--tp-size`: the tensor parallelism size. +- `--attention-backend`: the attention backend. +- `--config-list`: the list of speculative decoding configuration to test, the format is `,,,`. +- `--benchmark-list`: the list of benchmarks to test, the format is `::`. + +```shell +python3 bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --port 30000 \ + --trust-remote-code \ + --mem-fraction-static 0.8 \ + --tp-size 1 \ + --attention-backend fa3 \ + --config-list 1,0,0,0 1,3,1,4 \ + --benchmark-list mtbench gsm8k:5 ceval:5:accountant \ + --dtype bfloat16 +``` + +### Launch Benchmarker Independently + +If you want to launch the SGLang server independently, you can use the following command. + +```shell +# you can launch a server +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 1 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 +``` + +Then we can start benchmarking. Note that you should use the same host and port as the one used in the SGLang server. Note that `--skip-launch-server` is required to skip the launch of the SGLang server. + +```bash +python bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --port 30000 \ + --config-list 1,3,1,4 \ + --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \ + --skip-launch-server +``` diff --git a/progress/SpecForge/benchmarks/__init__.py b/progress/SpecForge/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfec7eeb81271e0d07feafe767881dba5dcf4acd --- /dev/null +++ b/progress/SpecForge/benchmarks/__init__.py @@ -0,0 +1,3 @@ +""" +Benchmark scripts for speculative decoding evaluation. +""" diff --git a/progress/SpecForge/benchmarks/bench_eagle3.py b/progress/SpecForge/benchmarks/bench_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..988e108f5e1f8ce82e9ccbeaf1b77a5d741fa816 --- /dev/null +++ b/progress/SpecForge/benchmarks/bench_eagle3.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +""" +Usage: + +# if you want to run benchmarks directly +# mtbench:20 means only run 20 samples in the dataset +python bench_eagle3.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --port 30000 \ + --config-list 1,0,0,0 1,3,1,4 \ + --benchmark-list mtbench:20 \ + --dtype bfloat16 + + +or if you want run sglang alone. + +# launch sglang +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 1 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 + +# then run benchmarks +python bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --port 30000 \ + --config-list 1,0,0,0 \ + --benchmark-list mtbench:80 \ + --dtype bfloat16 \ + --skip-launch-server +""" +import argparse +import json +import os +import time +from dataclasses import asdict +from typing import List + +import requests +from benchmarker import BENCHMARKS +from sglang.srt.server_args import ServerArgs +from sglang.test.test_utils import kill_process_tree, popen_launch_server +from sglang.utils import wait_for_server + + +def parse_args(): + parser = argparse.ArgumentParser() + sglang_group = parser.add_argument_group("sglang") + ServerArgs.add_cli_args(sglang_group) + + # make the follow args a group + benchmark_group = parser.add_argument_group("benchmark") + benchmark_group.add_argument( + "--skip-launch-server", action="store_true", default=False + ) + benchmark_group.add_argument("--timeout-for-server-launch", type=int, default=600) + benchmark_group.add_argument("--num-prompts", type=int, default=80) + benchmark_group.add_argument("--output-dir", type=str, default="./results") + benchmark_group.add_argument( + "--config-list", type=str, nargs="+", default=["1,0,0,0", "1,3,1,4"] + ) + benchmark_group.add_argument( + "--name", + type=str, + default=None, + help="name of this benchmark run, if provided, will be added to the output file name", + ) + benchmark_group.add_argument( + "--benchmark-list", + type=str, + nargs="+", + default=[ + "mtbench:80", + "gsm8k:200", + "humaneval:200", + "math500:200", + "ceval:200", + ], + help=f"The list of benchmarks to run. The format is ::,. We support the following benchmarks: {', '.join(BENCHMARKS.benchmarks.keys())}", + ) + benchmark_group.add_argument( + "--enable-multi-turn-conversation", + action="store_true", + default=False, + ) + return parser.parse_args() + + +def launch_sglang_server( + server_args: ServerArgs, + base_url: str, + batch_size: int, + steps: int, + topk: int, + num_draft_tokens: int, + timeout: int, +): + """ + This function launches the SGLang server with the given server arguments. + """ + sglang_args: List[str] = [] + if steps > 0: + sglang_args.extend( + [ + "--speculative-algorithm", + "EAGLE3", + "--speculative-num-steps", + str(steps), + "--speculative-eagle-topk", + str(topk), + "--speculative-num-draft-tokens", + str(num_draft_tokens), + "--speculative-draft-model-path", + server_args.speculative_draft_model_path, + ] + ) + + sglang_args.extend( + [ + "--cuda-graph-max-bs", + str(batch_size), + "--mem-fraction-static", + str(server_args.mem_fraction_static), + "--tp-size", + str(server_args.tp_size), + "--max-running-requests", + str(batch_size), + ] + ) + + if server_args.trust_remote_code: + sglang_args.extend(["--trust-remote-code"]) + + if server_args.disable_radix_cache: + sglang_args.extend(["--disable-radix-cache"]) + + if server_args.ep_size: + sglang_args.extend(["--ep-size", str(server_args.ep_size)]) + + if server_args.attention_backend: + sglang_args.extend(["--attention-backend", server_args.attention_backend]) + + if server_args.quantization: + sglang_args.extend(["--quantization", server_args.quantization]) + + if server_args.dtype: + sglang_args.extend(["--dtype", server_args.dtype]) + + process = popen_launch_server( + server_args.model_path, + base_url, + timeout=timeout, + other_args=sglang_args, + env={ + "SGLANG_RECORD_STEP_TIME": "1", + "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1", + **os.environ, + }, + ) + return process + + +def send_flush_cache_request(base_url: str): + requests.post(base_url + "/flush_cache") + + +def main(): + args = parse_args() + server_args: ServerArgs = ServerArgs.from_cli_args(args) + configs = [tuple(map(int, config.split(","))) for config in args.config_list] + + # split the arg into list of (bench_name, num_prompts) + benchmark_list = [] + for item in args.benchmark_list: + splits = item.split(":") + if len(splits) == 1: + bench_name = splits[0] + num_prompts = None + subset = None + elif len(splits) == 2: + bench_name, num_prompts = splits + subset = None + elif len(splits) == 3: + bench_name, num_prompts, subset = splits + subset = subset.split(",") + else: + raise ValueError(f"Invalid benchmark list format: {item}") + benchmark_list.append((bench_name, num_prompts, subset)) + assert len(benchmark_list) != 0, "the number of benchmark list is 0" + + base_url = f"http://localhost:{args.port}" + + results = {} + results["model"] = server_args.speculative_draft_model_path + + def run_benchmarks(batch_size: int, steps: int, topk: int, num_draft_tokens: int): + for benchmark_name, num_prompts, subset in benchmark_list: + print( + f"Running benchmark {benchmark_name} with {num_prompts} prompts, batch size {batch_size}, steps {steps}, topk {topk}, num_draft_tokens {num_draft_tokens}, subset {subset}" + ) + benchmarkder_cls = BENCHMARKS.get(benchmark_name) + num_prompts = int(num_prompts) if num_prompts is not None else None + if subset is None: + benchmarker = benchmarkder_cls(num_samples=num_prompts) + else: + benchmarker = benchmarkder_cls(num_samples=num_prompts, subset=subset) + metrics_list = benchmarker.run( + host=args.host, port=args.port, batch_size=batch_size + ) + send_flush_cache_request(base_url) + if benchmark_name not in results: + results[benchmark_name] = [] + results[benchmark_name].append( + dict( + batch_size=batch_size, + steps=steps, + topk=topk, + num_draft_tokens=num_draft_tokens, + metrics=[asdict(metric) for metric in metrics_list], + num_samples=num_prompts, + ) + ) + + if args.skip_launch_server: + batch_size = configs[0][0] if len(configs) > 0 else 8 + run_benchmarks(batch_size, None, None, None) + else: + # we itearate over each config from args + for batch_size, steps, topk, num_draft_tokens in configs: + process = launch_sglang_server( + server_args, + base_url, + batch_size, + steps, + topk, + num_draft_tokens, + args.timeout_for_server_launch, + ) + wait_for_server(base_url) + run_benchmarks(batch_size, steps, topk, num_draft_tokens) + kill_process_tree(process.pid) + process.wait() + + os.makedirs(args.output_dir, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") + result_file = os.path.join( + args.output_dir, + f"{args.name + '_' if args.name else ''}results_{timestamp}.jsonl", + ) + with open(result_file, "w") as f: + json.dump(results, f, indent=4) + print(f"Results saved to {result_file}") + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/benchmarks/benchmarker/__init__.py b/progress/SpecForge/benchmarks/benchmarker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e37fb99c2caa75be8ab58dc51f393f9a748c2b7 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/__init__.py @@ -0,0 +1,29 @@ +from .aime import AIMEBenchmarker +from .ceval import CEvalBenchmarker +from .financeqa import FinanceQABenchmarker +from .gpqa import GPQABenchmarker +from .gsm8k import GSM8KBenchmarker +from .humaneval import HumanEvalBenchmarker +from .livecodebench import LCBBenchmarker +from .math500 import Math500Benchmarker +from .mmlu import MMLUBenchmarker +from .mmstar import MMStarBenchmarker +from .mtbench import MTBenchBenchmarker +from .registry import BENCHMARKS +from .simpleqa import SimpleQABenchmarker + +__all__ = [ + "BENCHMARKS", + "AIMEBenchmarker", + "CEvalBenchmarker", + "GSM8KBenchmarker", + "HumanEvalBenchmarker", + "Math500Benchmarker", + "MTBenchBenchmarker", + "MMStarBenchmarker", + "GPQABenchmarker", + "FinanceQABenchmarker", + "MMLUBenchmarker", + "LCBBenchmarker", + "SimpleQABenchmarker", +] diff --git a/progress/SpecForge/benchmarks/benchmarker/aime.py b/progress/SpecForge/benchmarks/benchmarker/aime.py new file mode 100644 index 0000000000000000000000000000000000000000..fba473c2c6d2f27397f67b424423271288cd6ae7 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/aime.py @@ -0,0 +1,133 @@ +""" +AIME benchmark +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_aime_answer(output: str) -> Optional[str]: + """Extract final answer from AIME problem solution. + + AIME answers are typically integers between 0 and 999, and are usually + in \boxed{} format. + """ + # Try to find answer in \boxed{} format + boxed_pattern = r"\\boxed\{([^}]+)\}" + match = re.search(boxed_pattern, output) + if match: + answer = match.group(1).strip() + # Extract number from the boxed content + numbers = re.findall(r"\d+", answer) + if numbers: + return numbers[-1] # Take the last number (usually the final answer) + return answer + + # Try to find answer in \boxed format (without braces) + boxed_pattern2 = r"\\boxed\s+(\d+)" + match = re.search(boxed_pattern2, output) + if match: + return match.group(1).strip() + + # Look for patterns like "The answer is 42" or "Answer: 123" + answer_patterns = [ + r"(?:answer|Answer|ANSWER)[\s:]+(\d+)", + r"(?:final\s+answer|Final\s+Answer)[\s:]+(\d+)", + r"(?:is|equals?|=\s*)(\d+)\s*$", + ] + for pattern in answer_patterns: + matches = re.findall(pattern, output, re.IGNORECASE) + if matches: + return matches[-1].strip() + + # Fallback: extract the last integer in the text + numbers = re.findall(r"\b(\d+)\b", output) + if numbers: + # Filter to reasonable AIME answer range (0-999) + valid_numbers = [n for n in numbers if 0 <= int(n) <= 999] + if valid_numbers: + return valid_numbers[-1] + + return None + + +@BENCHMARKS.register("aime") +class AIMEBenchmarker(Benchmarker): + """AIME benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess AIME dataset.""" + dataset = load_dataset("Maxwell-Jia/AIME_2024")["train"] + questions = [] + labels = [] + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["Problem"]}) + # Extract answer from Answer field + answer = None + if "Answer" in q: + answer = str(q["Answer"]).strip() + elif "answer" in q: + answer = str(q["answer"]).strip() + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + return extract_aime_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for AIME by comparing numeric answers.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize answers for comparison + pred_normalized = str(pred).strip() + label_normalized = str(label).strip() + # Try exact match first + if pred_normalized == label_normalized: + correct += 1 + else: + # Try numeric comparison + try: + pred_num = int(pred_normalized) + label_num = int(label_normalized) + if pred_num == label_num: + correct += 1 + except ValueError: + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for AIME with reasoning prompt.""" + return create_simple_sgl_function( + function_name="reasoning_gen", + answer_key="answer", + user_prefix="\nPlease reason step by step, and put your final answer within \\boxed{}.", + ) + + def get_max_new_tokens(self) -> int: + """AIME problems require more tokens.""" + return 32768 diff --git a/progress/SpecForge/benchmarks/benchmarker/base.py b/progress/SpecForge/benchmarks/benchmarker/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f8da625319cc854688521b8d9bf1a4b98ac5006b --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/base.py @@ -0,0 +1,218 @@ +""" +Base class for benchmark implementations. +""" + +import time +from abc import ABC, abstractmethod +from argparse import Namespace +from typing import Any, Callable, Dict, List, Optional, Tuple + +from sglang import set_default_backend +from sglang.test.test_utils import select_sglang_backend + +from .utils import compute_metrics + + +class Benchmarker(ABC): + """ + Base class for benchmark implementations. + + Subclasses should implement: + - load_data(): Load and preprocess dataset + - create_sgl_function(): Create the SGL function for inference + + Optional overrides: + - extract_answer(): Extract answer from model output (if needed) + - compute_accuracy(): Compute accuracy metric (if applicable) + - get_answer_keys(): Get list of answer keys for multi-turn conversations + + Args: + num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used. + subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used. + """ + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + self.num_samples = num_samples + self.subset = subset + + @abstractmethod + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]: + """ + Load and preprocess the dataset. + + Returns: + Tuple of (questions, labels) where: + - questions: List of question dicts for SGL function + - labels: List of ground truth labels (can be None if not applicable) + """ + raise NotImplementedError + + @abstractmethod + def create_sgl_function(self) -> Callable: + """ + Create the SGL function for inference. + + Returns: + SGL function decorated with @sgl.function + """ + raise NotImplementedError + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]: + """ + Extract answer from model output. + + Args: + output: Raw model output string + label: Optional ground truth label for reference + + Returns: + Extracted answer, or None if extraction fails + """ + return output + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """ + Compute accuracy metric. + + Args: + predictions: List of predicted answers + labels: List of ground truth labels + + Returns: + Accuracy score (0-1), or None if not applicable + """ + return None + + def get_answer_keys(self) -> Optional[List[str]]: + """ + Get list of answer keys for multi-turn conversations. + + Returns: + List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn + """ + return None + + def get_max_new_tokens(self) -> int: + """ + Get maximum number of new tokens to generate. + + Returns: + Maximum tokens (default: 2048) + """ + return 2048 + + def run( + self, + host: str, + port: int, + batch_size: int, + max_new_tokens: int = None, + num_runs: int = 1, + ): + """ + Run the benchmark evaluation. + + This method handles the common workflow: + 1. Initialize backend + 2. Load data + 3. Create SGL function + 4. Run inference loops + 5. Compute metrics + 6. Print results + + Args: + host (str): The host of the SGLang server + port (int): The port of the SGLang server + batch_size (int): The number of prompts to process in parallel + num_samples (int): The number of samples to run the benchmark on. If not provided, all samples will be used. + max_new_tokens (int): Maximum number of new tokens to generate, default is 2048 + num_runs (int): The number of times to run this benchmark, default is 1. You can set it to a larger number if you want to get more stable results. + """ + if not host.startswith(("http://", "https://")): + host = f"http://{host}" + # Initialize backend + sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel") + set_default_backend(select_sglang_backend(sglang_args)) + + # Load data + questions, labels = self.load_data() + if len(questions) == 0: + print("No valid questions found. Please check the dataset format.") + return + + # Create SGL function + sgl_function = self.create_sgl_function() + + # Run evaluation loops + metrics_list = [] + answer_keys = self.get_answer_keys() + max_new_tokens = max_new_tokens or self.get_max_new_tokens() + + for _ in range(num_runs): + tic = time.perf_counter() + states = sgl_function.run_batch( + questions, + temperature=0, + max_new_tokens=max_new_tokens, + num_threads=batch_size, + progress_bar=True, + ) + latency = time.perf_counter() - tic + + # Extract predictions + predictions = [] + primary_answer_key = answer_keys[0] if answer_keys else "answer" + for i in range(len(states)): + # Access answer from state object (states[i] supports dict-like access) + output = states[i][primary_answer_key] + if isinstance(output, str): + extracted = self.extract_answer( + output, + (labels[i] if labels and i < len(labels) else None), + ) + else: + extracted = output + predictions.append(extracted) + + # Compute accuracy if applicable + accuracy = None + # Check if we have a labels list (even if all labels are None) + has_labels_list = labels and len(labels) > 0 + + if has_labels_list: + # Always call compute_accuracy if we have a labels list + # This allows it to return None, which will be displayed in print_results + accuracy = self.compute_accuracy(predictions, labels) + if accuracy is not None: + valid_count = sum(1 for p in predictions if p is not None) + if valid_count < len(predictions): + print( + f"Warning: {len(predictions) - valid_count} predictions could not be extracted." + ) + + # Compute performance metrics + metrics = compute_metrics( + states, + latency, + answer_key=primary_answer_key, + additional_answer_keys=( + answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None + ), + ) + # Always set accuracy if we have a labels list (even if compute_accuracy returns None) + # This allows print_results to show None when compute_accuracy returns None + if has_labels_list: + metrics.accuracy = ( + accuracy # Can be None if compute_accuracy returns None + ) + if accuracy is not None: + metrics.num_valid_predictions = sum( + 1 for p in predictions if p is not None + ) + + metrics_list.append(metrics) + return metrics_list diff --git a/progress/SpecForge/benchmarks/benchmarker/ceval.py b/progress/SpecForge/benchmarks/benchmarker/ceval.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b77ccbdb0deb5ce4d2c4522a157836cf0e6efb --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/ceval.py @@ -0,0 +1,267 @@ +""" +C-Eval benchmark evaluation script. +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import concatenate_datasets, load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_answer(answer_str: str) -> str: + """Extract the answer choice (A, B, C, D) from the model output.""" + # Try to find the answer in various formats + answer_str = answer_str.strip().upper() + + # Direct match for single letter + match = re.search(r"\b([ABCD])\b", answer_str) + if match: + return match.group(1) + + # Try to find answer in parentheses or brackets + for pattern in [ + r"\(([ABCD])\)", + r"\[([ABCD])\]", + r"答案[::]\s*([ABCD])", + r"Answer[::]\s*([ABCD])", + ]: + match = re.search(pattern, answer_str, re.IGNORECASE) + if match: + return match.group(1).upper() + + # Try to find the first occurrence of A, B, C, or D + match = re.search(r"([ABCD])", answer_str) + if match: + return match.group(1) + + return None + + +def format_question(question: str, options: List[str]) -> str: + """Format the question with options.""" + prompt = question + "\n\n选项:\n" + for i, option in enumerate(options): + prompt += f"{chr(65 + i)}. {option}\n" + prompt += "\n请从A、B、C、D中选择一个答案。" + return prompt + + +@BENCHMARKS.register("ceval") +class CEvalBenchmarker(Benchmarker): + """C-Eval benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + if subset is None: + subset = "all" + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]: + """Load and preprocess C-Eval dataset.""" + all_configs = [ + "accountant", + "advanced_mathematics", + "art_studies", + "basic_medicine", + "business_administration", + "chinese_language_and_literature", + "civil_servant", + "clinical_medicine", + "college_chemistry", + "college_economics", + "college_physics", + "college_programming", + "computer_architecture", + "computer_network", + "discrete_mathematics", + "education_science", + "electrical_engineer", + "environmental_impact_assessment_engineer", + "fire_engineer", + "high_school_biology", + "high_school_chemistry", + "high_school_chinese", + "high_school_geography", + "high_school_history", + "high_school_mathematics", + "high_school_physics", + "high_school_politics", + "ideological_and_moral_cultivation", + "law", + "legal_professional", + "logic", + "mao_zedong_thought", + "marxism", + "metrology_engineer", + "middle_school_biology", + "middle_school_chemistry", + "middle_school_geography", + "middle_school_history", + "middle_school_mathematics", + "middle_school_physics", + "middle_school_politics", + "modern_chinese_history", + "operating_system", + "physician", + "plant_protection", + "probability_and_statistics", + "professional_tour_guide", + "sports_science", + "tax_accountant", + "teacher_qualification", + "urban_and_rural_planner", + "veterinary_medicine", + ] + + # Select configs to load + if self.subset == "all": + configs_to_load = all_configs + else: + for subset in self.subset: + assert ( + subset in all_configs + ), f"Subset {subset} not found in C-Eval dataset" + configs_to_load = self.subset + + # Load datasets + try: + datasets = [] + for config in configs_to_load: + try: + ds = load_dataset("ceval/ceval-exam", name=config, split="test") + datasets.append(ds) + print(f"Loaded config '{config}' with {len(ds)} samples") + except Exception as e: + print(f"Warning: Failed to load config '{config}': {e}") + if len(datasets) == 0: + raise ValueError("No configs could be loaded") + dataset = concatenate_datasets(datasets) + print( + f"Successfully loaded C-Eval dataset with all configs (total: {len(dataset)} samples)" + ) + except Exception as e: + print(e) + print(f"Failed to load C-Eval dataset from 'ceval/ceval-exam': {e}") + print("Please ensure the dataset is available or install it manually.") + print("You can try: pip install datasets") + print("Or download from: https://huggingface.co/datasets/ceval/ceval-exam") + return [], [] + + # Process questions + questions = [] + labels = [] + for idx, item in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + # Handle different dataset formats + question_text = None + if "question" in item: + question_text = item["question"] + elif "inputs" in item: + question_text = item["inputs"] + elif "problem" in item: + question_text = item["problem"] + elif "content" in item: + question_text = item["content"] + + if not question_text: + continue + + # Get options - C-Eval typically has options as a list or dict + options = None + if "options" in item: + options = item["options"] + if isinstance(options, dict): + # Convert dict to list in order A, B, C, D + options = [ + options.get("A", ""), + options.get("B", ""), + options.get("C", ""), + options.get("D", ""), + ] + elif isinstance(options, list): + # Ensure we have 4 options + while len(options) < 4: + options.append("") + elif "choices" in item: + options = item["choices"] + if isinstance(options, dict): + options = [ + options.get("A", ""), + options.get("B", ""), + options.get("C", ""), + options.get("D", ""), + ] + else: + # Try to construct options from A, B, C, D fields + options = [ + item.get("A", item.get("option_A", "")), + item.get("B", item.get("option_B", "")), + item.get("C", item.get("option_C", "")), + item.get("D", item.get("option_D", "")), + ] + + # Filter out empty options + if options: + options = [str(opt).strip() for opt in options if opt] + if len(options) < 2: # Need at least 2 options + continue + else: + continue + + # Get answer + answer = None + if "answer" in item: + answer = str(item["answer"]).upper().strip() + elif "target" in item: + answer = str(item["target"]).upper().strip() + elif "label" in item: + answer = str(item["label"]).upper().strip() + elif "correct" in item: + answer = str(item["correct"]).upper().strip() + + # Validate answer + if answer and answer in ["A", "B", "C", "D"]: + # Format question + formatted_question = format_question(question_text, options) + questions.append({"question": formatted_question}) + labels.append(answer) + + if len(questions) == 0: + print("No valid questions found. Please check the dataset format.") + print( + "Sample item keys:", + list(dataset[0].keys()) if len(dataset) > 0 else "No items", + ) + return [], [] + + return questions, labels + + def create_sgl_function(self): + """Create SGL function for C-Eval.""" + return create_simple_sgl_function( + function_name="get_ceval_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def extract_answer(self, output: str, label: Any = None) -> str: + """Extract answer choice from model output.""" + return extract_answer(output) + + def compute_accuracy(self, predictions: List[str], labels: List[str]) -> float: + """Compute accuracy metric.""" + correct = 0 + valid_count = 0 + for i in range(len(predictions)): + if predictions[i] is not None: # Only count valid predictions + valid_count += 1 + if predictions[i] == labels[i]: + correct += 1 + return correct / valid_count if valid_count > 0 else 0.0 diff --git a/progress/SpecForge/benchmarks/benchmarker/financeqa.py b/progress/SpecForge/benchmarks/benchmarker/financeqa.py new file mode 100644 index 0000000000000000000000000000000000000000..9323b63423ba288edc79d2ecfb6a33d0a926af7c --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/financeqa.py @@ -0,0 +1,59 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +QUESTION_PROMPT = """ +Given the following context: + +{context} + +Can you answer the following question? + +{question} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + if row["context"] is None: + return row["question"].strip() + else: + question = QUESTION_PROMPT.format( + context=row["context"].strip(), + question=row["question"].strip(), + ) + return question + + +@BENCHMARKS.register("financeqa") +class FinanceQABenchmarker(Benchmarker): + """FinanceQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess FinanceQA dataset.""" + # Read data + ds = load_dataset("AfterQuery/FinanceQA")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_financeqa_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/gpqa.py b/progress/SpecForge/benchmarks/benchmarker/gpqa.py new file mode 100644 index 0000000000000000000000000000000000000000..e2add8fa835a076e51be350c9d95295e0f20bb31 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/gpqa.py @@ -0,0 +1,85 @@ +import random +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +GPQA_QUERY_TEMPLATE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + gold_index = random.randint(0, 3) + choices = [ + row["Incorrect Answer 1"], + row["Incorrect Answer 2"], + row["Incorrect Answer 3"], + ] + choices.insert(gold_index, row["Correct Answer"]) + + question = GPQA_QUERY_TEMPLATE.format( + Question=row["Question"].strip(), + A=choices[0].strip(), + B=choices[1].strip(), + C=choices[2].strip(), + D=choices[3].strip(), + ) + + # 0 means A, 1 means B, 2 means C, 3 means D + answer = ["A", "B", "C", "D"][gold_index] + return question, answer + + +@BENCHMARKS.register("gpqa") +class GPQABenchmarker(Benchmarker): + """GPQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess GPQA dataset.""" + # Read data + ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text, answer = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + if "Answer: " not in output: + return None + return output.split("Answer: ")[1].strip() + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_gpqa_mcq_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/gsm8k.py b/progress/SpecForge/benchmarks/benchmarker/gsm8k.py new file mode 100644 index 0000000000000000000000000000000000000000..10f8dbae82381cb1cc1a9e7ade454ea58a9da6c7 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/gsm8k.py @@ -0,0 +1,99 @@ +""" +GSM8K benchmark evaluation script. +""" + +import ast +import re +from typing import Any, Dict, List, Optional, Tuple + +from sglang.utils import download_and_cache_file, read_jsonl + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_few_shot_sgl_function + +INVALID = -9999999 + + +def get_one_example(lines: List[Dict], i: int, include_answer: bool) -> str: + """Format a single example.""" + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines: List[Dict], k: int) -> str: + """Get few-shot examples as a string.""" + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str: str) -> int: + """Extract numeric answer from model output.""" + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +@BENCHMARKS.register("gsm8k") +class GSM8KBenchmarker(Benchmarker): + """GSM8K benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + """Load and preprocess GSM8K dataset.""" + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) + + # Construct prompts + few_shot_examples = get_few_shot_examples(lines, 5) + + questions = [] + labels = [] + for i in range((len(lines))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = get_one_example(lines, i, False) + questions.append({"question": question_text}) + labels.append(get_answer_value(lines[i]["answer"])) + + # Store few_shot_examples for use in create_sgl_function + self.few_shot_examples = few_shot_examples + + assert all(l != INVALID for l in labels), "Some labels are invalid" + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + """Extract numeric answer from model output.""" + return get_answer_value(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for GSM8K by comparing numeric answers.""" + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for GSM8K with few-shot examples.""" + return create_few_shot_sgl_function( + few_shot_examples=self.few_shot_examples, + function_name="few_shot_gsm8k", + answer_key="answer", + stop=["Question", "Assistant:", "<|separator|>"], + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/humaneval.py b/progress/SpecForge/benchmarks/benchmarker/humaneval.py new file mode 100644 index 0000000000000000000000000000000000000000..6be1bdec5f6f421ff5787c02d0b0ace9375ebb0f --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/humaneval.py @@ -0,0 +1,188 @@ +""" +HumanEval benchmark evaluation script. +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_code_from_output(output: str) -> Optional[str]: + """Extract Python code from model output. + + Tries to extract code blocks or function definitions. + """ + # Try to find code in markdown code blocks + code_block_pattern = r"```(?:python)?\n(.*?)```" + match = re.search(code_block_pattern, output, re.DOTALL) + if match: + return match.group(1).strip() + + # Try to find function definition (common in HumanEval) + # Look for "def " followed by code until the next def or end of string + def_pattern = r"(def\s+\w+\([^)]*\):.*?)(?=\n\ndef\s+|\Z)" + match = re.search(def_pattern, output, re.DOTALL) + if match: + return match.group(1).strip() + + # Fallback: return the output as-is (might already be code) + return output.strip() if output.strip() else None + + +def check_code_passes_tests(code: str, test_code: str, entry_point: str) -> bool: + """Check if generated code passes the test cases. + + This is a simplified version. For full evaluation, use the official + HumanEval evaluation framework. + + HumanEval test code typically contains assertions that will raise + AssertionError if the code doesn't pass. If execution completes without + exceptions, the tests pass. + """ + try: + # Create a safe execution environment + namespace = {} + # Execute the code (function definition) + exec(code, namespace) + # Execute the test code (which contains assertions) + # If no exception is raised, the tests pass + exec(test_code, namespace) + return True + except AssertionError: + # Assertion failed - test didn't pass + return False + except Exception: + # Any other exception (syntax error, runtime error, etc.) means test failed + return False + + +@BENCHMARKS.register("humaneval") +class HumanEvalBenchmarker(Benchmarker): + """HumanEval benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + """Initialize benchmark and store test cases.""" + super().__init__(num_samples, None) + self.test_cases = [] + self.entry_points = [] + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[Dict[str, str]]]]: + """Load and preprocess HumanEval dataset.""" + dataset = load_dataset("openai/openai_humaneval")["test"] + questions = [] + labels = [] + self.test_cases = [] + self.entry_points = [] + + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["prompt"]}) + + # Store test case and entry point for evaluation + test_code = q.get("test", "") + entry_point = q.get("entry_point", "") + self.test_cases.append(test_code) + self.entry_points.append(entry_point) + + # Store canonical solution as reference (optional, for comparison) + canonical_solution = q.get("canonical_solution", "") + labels.append( + { + "test": test_code, + "entry_point": entry_point, + "canonical_solution": canonical_solution, + } + ) + + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract code from model output.""" + return extract_code_from_output(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for HumanEval by checking if code passes tests. + + Note: This is a simplified evaluation. For official pass@k metrics, + use the HumanEval evaluation framework. + """ + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + + for i, (pred, label) in enumerate(zip(predictions, labels)): + if label is not None and isinstance(label, dict): + valid_count += 1 + if pred is not None: + try: + # Get the prompt (function signature and docstring) + prompt = self.questions[i]["question"] + entry_point = label.get("entry_point", "") + + # The prompt contains the function signature (e.g., "def function_name(...):") + # The generated code might be: + # 1. Just the function body (what we want) - need to combine with prompt + # 2. The complete function including signature - use as-is + # 3. Code in markdown blocks - already extracted by extract_code_from_output + + pred_str = str(pred).strip() + + # Check if pred already contains a complete function definition + # (starts with "def " and contains the entry_point function name) + if pred_str.startswith("def ") and entry_point: + # Check if this is the same function (by name) + func_name_match = re.match(r"def\s+(\w+)\s*\(", pred_str) + if ( + func_name_match + and func_name_match.group(1) == entry_point + ): + # Generated code includes complete function, use it as-is + full_code = pred_str + else: + # Different function or no match, combine with prompt + full_code = prompt + "\n" + pred_str + elif pred_str.startswith("def "): + # Has function definition but we can't verify entry_point, use as-is + full_code = pred_str + else: + # Generated code is just the body, combine with prompt + full_code = prompt + "\n" + pred_str + + # Check if code passes tests + test_code = label.get("test", "") + + if test_code and check_code_passes_tests( + full_code, test_code, entry_point + ): + correct += 1 + except Exception as e: + # If evaluation fails, consider it incorrect + # Uncomment for debugging: print(f"Error evaluating code {i}: {e}") + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for HumanEval.""" + return create_simple_sgl_function( + function_name="get_humaneval_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def get_max_new_tokens(self) -> int: + """HumanEval code generation requires more tokens.""" + return 1024 diff --git a/progress/SpecForge/benchmarks/benchmarker/livecodebench.py b/progress/SpecForge/benchmarks/benchmarker/livecodebench.py new file mode 100644 index 0000000000000000000000000000000000000000..490ba2b20349ecd68a3edc468d38ef377c6e8d05 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/livecodebench.py @@ -0,0 +1,46 @@ +""" +GSM8K benchmark evaluation script. +""" + +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def generate_question(row: Dict[str, Any]) -> str: + question = row["question_content"].strip() + return question + + +@BENCHMARKS.register("livecodebench") +class LCBBenchmarker(Benchmarker): + """LiveCodeBench benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + ds = load_dataset("livecodebench/code_generation")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_livecodebench_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/math500.py b/progress/SpecForge/benchmarks/benchmarker/math500.py new file mode 100644 index 0000000000000000000000000000000000000000..64ca48eb386aa6f388ef997c34de496dad4db1b7 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/math500.py @@ -0,0 +1,122 @@ +""" +MATH-500 benchmark evaluation script. +""" + +import re +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def extract_math_answer(output: str) -> Optional[str]: + """Extract final answer from math problem solution. + + Tries to extract answer from \boxed{} format first, then looks for + the last number in the output. + """ + # Try to find answer in \boxed{} format + boxed_pattern = r"\\boxed\{([^}]+)\}" + match = re.search(boxed_pattern, output) + if match: + return match.group(1).strip() + + # Try to find answer in \boxed format (without braces) + boxed_pattern2 = r"\\boxed\s+([^\s]+)" + match = re.search(boxed_pattern2, output) + if match: + return match.group(1).strip() + + # Try to find the last number (could be integer or decimal) + # Look for patterns like "The answer is 42" or "Answer: 3.14" + answer_patterns = [ + r"(?:answer|Answer|ANSWER)[\s:]+([-+]?\d*\.?\d+)", + r"(?:is|equals?|=\s*)([-+]?\d*\.?\d+)\s*$", + ] + for pattern in answer_patterns: + matches = re.findall(pattern, output, re.IGNORECASE) + if matches: + return matches[-1].strip() + + # Fallback: extract the last number in the text + numbers = re.findall(r"[-+]?\d*\.?\d+", output) + if numbers: + return numbers[-1] + + return None + + +@BENCHMARKS.register("math500") +class Math500Benchmarker(Benchmarker): + """MATH-500 benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess MATH-500 dataset.""" + dataset = load_dataset("HuggingFaceH4/MATH-500")["test"] + questions = [] + labels = [] + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + questions.append({"question": q["problem"]}) + # Extract answer from solution or answer field + answer = None + if "answer" in q: + answer = str(q["answer"]).strip() + elif "solution" in q: + # Try to extract from solution + answer = extract_math_answer(q["solution"]) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + return extract_math_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for MATH-500 by comparing answers.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize answers for comparison (remove whitespace, handle different formats) + pred_normalized = str(pred).strip().lower() + label_normalized = str(label).strip().lower() + # Try exact match first + if pred_normalized == label_normalized: + correct += 1 + else: + # Try numeric comparison if both are numbers + try: + pred_num = float(pred_normalized) + label_num = float(label_normalized) + if abs(pred_num - label_num) < 1e-6: + correct += 1 + except ValueError: + pass + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for MATH-500.""" + return create_simple_sgl_function( + function_name="get_math500_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/mmlu.py b/progress/SpecForge/benchmarks/benchmarker/mmlu.py new file mode 100644 index 0000000000000000000000000000000000000000..407339a82e2f1d86d8829a33ededb2201f3b2ee2 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/mmlu.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + +GPQA_QUERY_TEMPLATE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + + +def generate_question(row: Dict[str, Any]) -> str: + choices = row["choices"] + question = GPQA_QUERY_TEMPLATE.format( + Question=row["question"].strip(), + A=choices[0].strip(), + B=choices[1].strip(), + C=choices[2].strip(), + D=choices[3].strip(), + ) + + # 0 means A, 1 means B, 2 means C, 3 means D + answer = ["A", "B", "C", "D"][row["answer"]] + print(answer) + return question, answer + + +@BENCHMARKS.register("mmlu") +class MMLUBenchmarker(Benchmarker): + """MMLU benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + if subset is None: + subset = ["all"] + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + questions = [] + labels = [] + + for subset in self.subset: + ds = load_dataset("cais/mmlu", subset)["test"] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text, answer = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(answer) + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[int]: + if "Answer: " not in output: + return None + return output.split("Answer: ")[1].strip() + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + if not labels or len(labels) == 0: + return None + correct = sum(1 for pred, label in zip(predictions, labels) if pred == label) + return correct / len(labels) if len(labels) > 0 else 0.0 + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_mmlu_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/mmstar.py b/progress/SpecForge/benchmarks/benchmarker/mmstar.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab1c44a28023a6bf18277edcacbe96794fa2c6a --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/mmstar.py @@ -0,0 +1,185 @@ +""" +MMStar benchmark evaluation script. +""" + +import os +import re +import shutil +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_image_sgl_function + + +def extract_mmstar_answer( + output: str, options: Optional[List[str]] = None +) -> Optional[str]: + """Extract answer from MMStar model output. + + MMStar questions typically have multiple choice options (A, B, C, D, etc.) + """ + output_upper = output.strip().upper() + + # Try to find answer choice (A, B, C, D, etc.) + # Direct match for single letter + match = re.search(r"\b([A-Z])\b", output_upper) + if match: + letter = match.group(1) + if options and len(options) > 0: + # Validate that the letter is within valid range + max_option = chr(64 + len(options)) # 'A' + (len-1) + if "A" <= letter <= max_option: + return letter + else: + # Assume A-D are valid + if "A" <= letter <= "D": + return letter + + # Try to find answer in parentheses or brackets + for pattern in [ + r"\(([A-Z])\)", + r"\[([A-Z])\]", + r"答案[::]\s*([A-Z])", + r"Answer[::]\s*([A-Z])", + r"选择[::]\s*([A-Z])", + ]: + match = re.search(pattern, output_upper) + if match: + letter = match.group(1) + if options and len(options) > 0: + max_option = chr(64 + len(options)) + if "A" <= letter <= max_option: + return letter + elif "A" <= letter <= "D": + return letter + + return None + + +@BENCHMARKS.register("mmstar") +class MMStarBenchmarker(Benchmarker): + """MMStar benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + """Initialize benchmark and set up cache directory.""" + self.cache_dir = None + self.options_list = [] # Store options for each question + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[Optional[str]]]: + """Load and preprocess MMStar dataset.""" + self.cache_dir = os.path.join(".cache", "mmstar_specforge") + image_dir = os.path.join(self.cache_dir, "images") + os.makedirs(self.cache_dir, exist_ok=True) + os.makedirs(image_dir, exist_ok=True) + print(f"Created temporary image directory: {self.cache_dir}") + + dataset = load_dataset("Lin-Chen/MMStar")["val"] + questions = [] + labels = [] + self.options_list = [] + + for idx, q in enumerate(dataset): + if self.num_samples is not None and idx >= self.num_samples: + break + + image = q["image"] + image_path = os.path.join(self.cache_dir, q["meta_info"]["image_path"]) + image.convert("RGB").save(image_path, "JPEG") + + # Extract question and options + question_full = q["question"] + if "Options:" in question_full: + question_text, options_text = question_full.split("Options:", 1) + question_text = question_text.strip() + # Parse options (typically A. option1 B. option2 etc.) + options = [] + for line in options_text.strip().split("\n"): + line = line.strip() + if line and re.match(r"^[A-Z]\.", line): + option_text = re.sub(r"^[A-Z]\.\s*", "", line).strip() + options.append(option_text) + self.options_list.append(options) + else: + question_text = question_full.strip() + self.options_list.append([]) + + item = { + "image_path": image_path, + "question": question_text, + } + questions.append(item) + + # Extract ground truth answer + answer = None + if "answer" in q: + answer = str(q["answer"]).strip().upper() + elif "correct_answer" in q: + answer = str(q["correct_answer"]).strip().upper() + elif "ground_truth" in q: + answer = str(q["ground_truth"]).strip().upper() + + # Validate answer is a valid option letter + if answer and len(answer) == 1 and "A" <= answer <= "Z": + if self.options_list[-1]: + max_option = chr(64 + len(self.options_list[-1])) + if answer <= max_option: + labels.append(answer) + else: + labels.append(None) + else: + labels.append(answer) + else: + labels.append(None) + + return questions, labels + + def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]: + """Extract answer from model output.""" + # Use the options for the current question if available + # Note: We can't easily get the question index here, so we'll use a simpler approach + return extract_mmstar_answer(output) + + def compute_accuracy( + self, predictions: List[Any], labels: List[Any] + ) -> Optional[float]: + """Compute accuracy for MMStar by comparing answer choices.""" + if not labels or len(labels) == 0: + return None + if all(label is None for label in labels): + return None + + correct = 0 + valid_count = 0 + for pred, label in zip(predictions, labels): + if label is not None: + valid_count += 1 + if pred is not None: + # Normalize to uppercase for comparison + pred_normalized = str(pred).strip().upper() + label_normalized = str(label).strip().upper() + if pred_normalized == label_normalized: + correct += 1 + + return correct / valid_count if valid_count > 0 else 0.0 + + def create_sgl_function(self): + """Create SGL function for MMStar (image-based Q&A).""" + return create_image_sgl_function( + function_name="get_mmstar_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) + + def run(self, *args, **kwargs): + """Run benchmark and clean up cache directory.""" + try: + return super().run(*args, **kwargs) + finally: + # Clean up cache directory + if self.cache_dir and os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + print(f"Deleted temporary directory: {self.cache_dir}") diff --git a/progress/SpecForge/benchmarks/benchmarker/mtbench.py b/progress/SpecForge/benchmarks/benchmarker/mtbench.py new file mode 100644 index 0000000000000000000000000000000000000000..46f2d1d611c8065219d65b221c1c15ae9409e21f --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/mtbench.py @@ -0,0 +1,59 @@ +""" +MT-Bench benchmark evaluation script. +Adapted from https://github.com/chromecast56/sglang/blob/6f145d2eadb93a116134f703358ce76f15381045/benchmark/mtbench/bench_sglang.py +""" + +from typing import Any, Dict, List, Optional, Tuple + +from sglang.utils import download_and_cache_file, read_jsonl + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_multi_turn_sgl_function + +SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." + + +@BENCHMARKS.register("mtbench") +class MTBenchBenchmarker(Benchmarker): + """MT-Bench benchmark implementation.""" + + def __init__( + self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None + ): + # support categorical data for mtbench + if subset is None: + subset = ["all"] + super().__init__(num_samples, subset) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[None]]: + """Load and preprocess MT-Bench dataset.""" + url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" + download_and_cache_file(url, filename="mtbench.jsonl") + questions_data = list(read_jsonl("mtbench.jsonl")) + questions_data = questions_data + + questions = [ + {"question_1": q["turns"][0], "question_2": q["turns"][1]} + for q in questions_data + ] + # MT-Bench doesn't have labels for accuracy computation + labels = [None] * len(questions) + + if self.num_samples is not None: + questions = questions[: self.num_samples] + labels = labels[: self.num_samples] + return questions, labels + + def create_sgl_function(self): + """Create SGL function for MT-Bench (2-turn conversation).""" + return create_multi_turn_sgl_function( + function_name="answer_mt_bench", + system_prompt=SYSTEM_PROMPT, + num_turns=2, + max_tokens=self.get_max_new_tokens(), + ) + + def get_answer_keys(self) -> List[str]: + """Return answer keys for multi-turn conversation.""" + return ["answer_1", "answer_2"] diff --git a/progress/SpecForge/benchmarks/benchmarker/registry.py b/progress/SpecForge/benchmarks/benchmarker/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4f474fcd15bd9a891b8f8977465aaa233c9fd1 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/registry.py @@ -0,0 +1,31 @@ +class BenchmarkRegistry: + + def __init__(self): + self.benchmarks = {} + + def register(self, name: str): + """ + Usage: + ```python + BENCHMARKS = BenchmarkRegistry() + + BENCHMARKS.register("aime") + class AIMEBenchmarker(Benchmarker): + ... + ``` + """ + + def wrapper(cls): + self.benchmarks[name] = cls + return cls + + return wrapper + + def get(self, name: str) -> type: + """ + Get the benchmark class by name. + """ + return self.benchmarks[name] + + +BENCHMARKS = BenchmarkRegistry() diff --git a/progress/SpecForge/benchmarks/benchmarker/simpleqa.py b/progress/SpecForge/benchmarks/benchmarker/simpleqa.py new file mode 100644 index 0000000000000000000000000000000000000000..5facab00d719d6d235a8cb50d161679ebe28f6a0 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/simpleqa.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset + +from .base import Benchmarker +from .registry import BENCHMARKS +from .utils import create_simple_sgl_function + + +def generate_question(row: Dict[str, Any]) -> str: + question = row["problem"].strip() + return question + + +@BENCHMARKS.register("simpleqa") +class SimpleQABenchmarker(Benchmarker): + """SimpleQA benchmark implementation.""" + + def __init__(self, num_samples: Optional[int] = None): + super().__init__(num_samples, None) + + def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]: + # Read data + ds = load_dataset("basicv8vc/SimpleQA")["test"] + + questions = [] + labels = [] + for i in range((len(ds))): + if self.num_samples is not None and i >= self.num_samples: + break + + question_text = generate_question(ds[i]) + questions.append({"question": question_text}) + labels.append(None) + return questions, labels + + def create_sgl_function(self): + return create_simple_sgl_function( + function_name="get_simpleqa_answer", + answer_key="answer", + max_tokens=self.get_max_new_tokens(), + ) diff --git a/progress/SpecForge/benchmarks/benchmarker/utils.py b/progress/SpecForge/benchmarks/benchmarker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a6dabfb9a4ef1789b7a89a5d7131755a6e6fa8 --- /dev/null +++ b/progress/SpecForge/benchmarks/benchmarker/utils.py @@ -0,0 +1,273 @@ +""" +Utility functions for benchmark scripts. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import sglang as sgl + + +@dataclass +class BenchmarkMetrics: + """Container for benchmark performance metrics.""" + + latency: float + output_throughput: float + accept_length: float + accuracy: Optional[float] = None + num_questions: int = 0 + num_valid_predictions: int = 0 + categorical_performance: Optional[Dict[str, "BenchmarkMetrics"]] = None + + +def compute_metrics( + states: List[Any], + latency: float, + answer_key: str = "answer", + additional_answer_keys: Optional[List[str]] = None, +) -> BenchmarkMetrics: + """ + Compute performance metrics from SGLang states. + + Args: + states: List of SGLang state objects from run_batch + latency: Total latency in seconds + answer_key: Primary key for answer in state meta info + additional_answer_keys: Additional keys to include in token count (e.g., ["answer_1", "answer_2"]) + + Returns: + BenchmarkMetrics object with computed metrics + """ + # Compute output tokens + num_output_tokens = 0 + if additional_answer_keys: + for key in [answer_key] + additional_answer_keys: + num_output_tokens += sum( + s.get_meta_info(key)["completion_tokens"] for s in states + ) + else: + num_output_tokens = sum( + s.get_meta_info(answer_key)["completion_tokens"] for s in states + ) + + output_throughput = num_output_tokens / latency if latency > 0 else 0.0 + + # Compute accept length (speculative decoding metric) + has_verify = "spec_verify_ct" in states[0].get_meta_info(answer_key) + if has_verify: + num_verify_tokens = 0 + if additional_answer_keys: + for key in [answer_key] + additional_answer_keys: + num_verify_tokens += sum( + s.get_meta_info(key).get("spec_verify_ct", 0) for s in states + ) + else: + num_verify_tokens = sum( + s.get_meta_info(answer_key).get("spec_verify_ct", 0) for s in states + ) + + if num_verify_tokens == 0: + accept_length = 1.0 + else: + accept_length = num_output_tokens / num_verify_tokens + else: + accept_length = 1.0 + + return BenchmarkMetrics( + latency=latency, + output_throughput=output_throughput, + accept_length=accept_length, + num_questions=len(states), + ) + + +def print_results( + metrics_list: List[BenchmarkMetrics], + benchmark_name: str, + show_accuracy: bool = False, +): + """ + Print benchmark results in a formatted way. + + Args: + metrics_list: List of BenchmarkMetrics from multiple runs + benchmark_name: Name of the benchmark + show_accuracy: Whether to show accuracy metrics + """ + avg_latency = np.mean([m.latency for m in metrics_list]) + avg_throughput = np.mean([m.output_throughput for m in metrics_list]) + avg_accept_length = np.mean([m.accept_length for m in metrics_list]) + + print(f"\n{'='*50}") + print(f"{benchmark_name} Evaluation Results") + print(f"{'='*50}") + print(f"Number of questions: {metrics_list[0].num_questions}") + if show_accuracy: + if metrics_list[0].accuracy is not None: + avg_accuracy = np.mean( + [m.accuracy for m in metrics_list if m.accuracy is not None] + ) + print(f"Average Accuracy: {avg_accuracy:.4f} ({avg_accuracy*100:.2f}%)") + else: + print(f"Average Accuracy: None") + print(f"Average Latency: {avg_latency:.3f} s") + print(f"Average Output throughput: {avg_throughput:.3f} token/s") + print(f"Average Accept length: {avg_accept_length:.3f}") + print(f"{'='*50}\n") + + +def create_simple_sgl_function( + function_name: str = "get_answer", + answer_key: str = "answer", + system_prompt: Optional[str] = None, + max_tokens: int = 2048, + stop: Optional[List[str]] = None, + user_prefix: Optional[str] = None, +) -> Callable: + """ + Create a simple SGL function for single-turn Q&A. + + Args: + function_name: Name of the function + answer_key: Key for storing the answer + system_prompt: Optional system prompt + max_tokens: Maximum tokens to generate + stop: Optional stop sequences + user_prefix: Optional suffix to append to user message (appended after question) + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, question): + if system_prompt: + s += sgl.system(system_prompt) + user_content = question + if user_prefix: + user_content = question + user_prefix + s += sgl.user(user_content) + gen_kwargs = {"max_tokens": max_tokens} + if stop: + gen_kwargs["stop"] = stop + s += sgl.assistant(sgl.gen(answer_key, **gen_kwargs)) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_few_shot_sgl_function( + few_shot_examples: str, + function_name: str = "few_shot_answer", + answer_key: str = "answer", + max_tokens: int = 512, + stop: Optional[List[str]] = None, +) -> Callable: + """ + Create an SGL function for few-shot learning. + + Args: + few_shot_examples: String containing few-shot examples + function_name: Name of the function + answer_key: Key for storing the answer + max_tokens: Maximum tokens to generate + stop: Optional stop sequences + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, question): + s += few_shot_examples + question + gen_kwargs = {"max_tokens": max_tokens} + if stop: + gen_kwargs["stop"] = stop + s += sgl.gen(answer_key, **gen_kwargs) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_multi_turn_sgl_function( + function_name: str = "multi_turn_answer", + system_prompt: Optional[str] = None, + num_turns: int = 2, + max_tokens: int = 2048, +) -> Callable: + """ + Create an SGL function for multi-turn conversations (e.g., MT-Bench with 2 turns). + + Args: + function_name: Name of the function + system_prompt: Optional system prompt + num_turns: Number of conversation turns (default: 2) + max_tokens: Maximum tokens to generate per turn + + Returns: + SGL function decorated with @sgl.function + """ + if num_turns == 2: + # Most common case: 2-turn conversation + @sgl.function + def sgl_func(s, question_1, question_2): + if system_prompt: + s += sgl.system(system_prompt) + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=max_tokens)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=max_tokens)) + + else: + # Generic case: create function with dynamic number of turns + # Note: This requires the caller to pass arguments as a dict + @sgl.function + def sgl_func(s, **kwargs): + if system_prompt: + s += sgl.system(system_prompt) + for i in range(num_turns): + question_key = f"question_{i+1}" + answer_key = f"answer_{i+1}" + if question_key in kwargs: + s += sgl.user(kwargs[question_key]) + s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens)) + + sgl_func.__name__ = function_name + return sgl_func + + +def create_image_sgl_function( + function_name: str = "get_image_answer", + answer_key: str = "answer", + max_tokens: int = 2048, +) -> Callable: + """ + Create an SGL function for image-based Q&A. + + Args: + function_name: Name of the function + answer_key: Key for storing the answer + max_tokens: Maximum tokens to generate + + Returns: + SGL function decorated with @sgl.function + """ + + @sgl.function + def sgl_func(s, image_path, question, **kwargs): + """ + The body of the SGL function: constructs a multimodal conversation flow. + + - First, it inputs an image + text question as 'user'. + - Then, it generates an answer as 'assistant', binding the response to the specified `answer_key`. + + Note: sgl.image() automatically encodes the image into a format supported by the model for multimodal input. + """ + # User input: Image + Text question + s += sgl.user(sgl.image(image_path) + question) + s += sgl.assistant(sgl.gen(answer_key, max_tokens=max_tokens)) + + sgl_func.__name__ = function_name + return sgl_func diff --git a/progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py b/progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py new file mode 100644 index 0000000000000000000000000000000000000000..1a465443c6703ff930aafb83415e2c0ace7f19a2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/26/c26l7dxpqbfol7d62sqakxdv4rgyh27yhm4hrctevbkw5t6kekia.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py b/progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py new file mode 100644 index 0000000000000000000000000000000000000000..13d410084c73f8592cb946245ce8fc79aaa18d2e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2d/c2d4e47kqxxnp6455gvkteqq3r336462zkbitosyeko6znxktn2b.py @@ -0,0 +1,879 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6g/c6gb52skvqs7or57vd3zu5um3r5rnmeimd5qam27l5j7uqx7t4ai.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream3) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream3) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream3 = get_raw_stream(3) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream3) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 96 + arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + arg2_1 = 96 + arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg4_1 = 96 + arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg7_1 = 96 + arg8_1 = 96 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py b/progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py new file mode 100644 index 0000000000000000000000000000000000000000..8396140d62eff35e58f5c5b4fbbc602fa1b9529b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2g/c2gswut4q57fp2ueybipg5qfqiy5coitofujwdnvqdwhr7nbvnyq.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config b/progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2j/4b74fa21eaaf86b6290185f6fe50aec9b905d858a087238ceddb52477f3f6acb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py b/progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py new file mode 100644 index 0000000000000000000000000000000000000000..555576b65175e401e3412393a13cfa687097afaa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py b/progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py new file mode 100644 index 0000000000000000000000000000000000000000..913928d0621bf1518e6aa3d50541643337978b4a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2n/c2ngvuchx6agpdr6v7awl3qgblaehfzaauoxn6camwvtk7syoxsk.py @@ -0,0 +1,715 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6u/c6uror2yjtc6vpcc3on3oq3lwi6yghlxrmwz5rocw5haxvfiz47e.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream2) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py b/progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3d487187076a492a2867ff38a103a8aed399fead --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2n/c2nooi7ekpz4qvmvghggbegd5cyfspb27jmq2snbi26zbrpoibnx.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config b/progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2n/d17ff4e7bb44e5ae89a267ef332bb7c074804ce0942fc0694c3ef15b05f7854a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py b/progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py new file mode 100644 index 0000000000000000000000000000000000000000..23e99aa77651e7f24c650c4ecc7576c873ad979a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2o/c2oashzxz74kzyuwo67tuhk32cike37ysabriftachdv7lf2qxgs.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py b/progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py new file mode 100644 index 0000000000000000000000000000000000000000..343fb5178bb02b9faa6061b2b03607fd8e1b7ade --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2v/c2vob47d7sxpitzmofyr55f5hvxsitxjhpyv5hdiqcdjgbwmxk76.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py b/progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf65f71c7bdac59005662b2c77f08c968160b3e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2y/c2yhndikcsebqfmbw7l44gmcdoyw7ogaqt7quyeygz3mp5w6u6ke.py @@ -0,0 +1,715 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/qr/cqrng7hmawuvea5b46xnw26e3vaokywqdqnuhn4vt7tmtdoleeab.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py b/progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad5ba9ac7511748dcb63b8ae68c65bec8dc8808 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2z/c2zdv5arszdl6ednyphqfnib6jwgzomr6zt6536b7gq75kp67uvh.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:4" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:4" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:4" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:4" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream4) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 960 + primals_11 = 960 + primals_7 = 8 + primals_8 = 8 + primals_12 = 8 + primals_14 = 8 + primals_16 = 8 + primals_17 = 8 + primals_19 = 8 + primals_22 = 8 + primals_21 = 8 + primals_24 = 8 + primals_27 = 8 + primals_26 = 8 + primals_2 = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 960, 128), (3932160, 122880, 128, 1), device='cuda:4', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 960), (30720, 960, 1), device='cuda:4', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py b/progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py new file mode 100644 index 0000000000000000000000000000000000000000..40d77a62e0225881517920a43a111914451e03e2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/2z/c2zqq6qyjomc7iflknbqr7yjdhjux47hzv4nnsi5qfbeqglaip2h.py @@ -0,0 +1,707 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cf/ccftanvnrini6kruughcnjtpfiarn7zwa2sdotthpo3wbbjituv3.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream0) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config b/progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3ee655a7a0e1446928adcd473d1a1d0657ad7c3b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/32/8d96bbe05a966b7e7756831f09a79e31bf46fad0952af86f36d75557fc1735e8.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 13, "triton_cache_hash": "BGHEC74L2RGBNBI3A4UJOTHXFUUKS4KY3KJKVN65FHLWR47O6USQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py b/progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d905946db57a3229e0b2e07e835478e522b8e2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/32/c32pbcuz72bjfnkzvckfbbzlzuupc5yxl7t47b3qf74mmk5g2d2z.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 282624}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 35328 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config b/progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config new file mode 100644 index 0000000000000000000000000000000000000000..1da043596702f38030a1f9fe4001039fad10fcd8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3b/a0a6b043ab548fdf71e72bbdf5daab7f72e9ed11a9ad9f8824a6263bb6bc5081.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "DSCNRRQHW6TSEFKL6AMK6FYZWMIHBTRCG2BE5YK5T7Q76TMOZ5HQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py b/progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7caca9e71afd9956e3c960c30fdfbd3a07b9d4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 348160}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 43520 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config b/progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d131a13c7786a0afd9ad24a5628260bf7969ba42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3f/3f6057605b157d44fd56f748226a63975b79198f94871188e73e46cd6c7f8792.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py b/progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py new file mode 100644 index 0000000000000000000000000000000000000000..cae642309057dffd9319a5de1fd2bbe0cdcf8e8b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3f/c3fttv7enp2yvnla3r6jkk4galt2qdpxw577ghvkmmx6zqaqla74.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py b/progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py new file mode 100644 index 0000000000000000000000000000000000000000..084acd3310bc4d99b5192a9aefe7d9a21a841f62 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3n/c3nlaqknekmjv2zuxzow4rf42v3gorxnfp6uod3dg3ic5ibp6yp3.py @@ -0,0 +1,715 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:1" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:1" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg9_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg10_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg11_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg12_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg13_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg14_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg15_1, (1, 1, 5, 5), (25, 25, 5, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream1) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 528 + arg1_1 = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = 528 + arg3_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg4_1 = 528 + arg5_1 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32) + arg7_1 = 528 + arg8_1 = 528 + arg9_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:1', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py b/progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2263771ab25159069df58d21510b194324e6fa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3q/c3qbvcsx2w7qss2v3eocuadgz6t35joo33bflzqkxzzj747zcjpk.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config b/progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e3d800553617fd03757f50025ba8220d370ba866 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3q/fc5920467dd1501963c976e2b895fc37747fdebfa098fff912209055f3a31828.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "2685c2d349c32243d4ee216505dfdf1e257d04d8316595ed69d4ca3499146788", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py b/progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py new file mode 100644 index 0000000000000000000000000000000000000000..2068c3a2c7f507d8c01e93a37a0eab0bc14105a6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3r/c3rkwwyedldrjz6sidtx5huqcsdgpdpu4xndmm6h4e4boo6cbg2w.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (976, 976, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 31232 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg4_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg5_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg6_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg7_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg8_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg9_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg10_1, (1, 1, 8, 8), (64, 64, 8, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 8, 1, 32, stream=stream7) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, 31232, stream=stream7) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py b/progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py new file mode 100644 index 0000000000000000000000000000000000000000..17a9b6f51407f52b6590b3db9f783fad5d3f2e66 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py b/progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py new file mode 100644 index 0000000000000000000000000000000000000000..7d43c5894ed5df429e8dd9153f9c6f72f4488a6e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/3z/c3zilfzjywngbdehwphwkhzpt6qcv6jecvzdajl2d5hb73xe6yzw.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config b/progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ba61f75017bba67ac35906919dd2dcfe202f258f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4a/7887d45b1aa6124e232769adbe995f9cc2af0dd187cb9928540172d82c7b8631.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "BZAXIZYYJGUVREZ5ANMEKVK5UU77TPVNED7QAB22EKNJIKFVURYA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4a/c4anvofd7zpauvjftktjqqrtqgoz7sd5ckg3wf7op3c7cmnbrzz6.py b/progress/SpecForge/cache/compiled_kernels/4a/c4anvofd7zpauvjftktjqqrtqgoz7sd5ckg3wf7op3c7cmnbrzz6.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbb72f86f69b10e329c5c146ebe73a515aabe6c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4a/c4anvofd7zpauvjftktjqqrtqgoz7sd5ckg3wf7op3c7cmnbrzz6.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 159744}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 19968 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/4b/c4bs4faiglroap6s5exqskt27kcrendxbin52u3h5veyhjham26s.py b/progress/SpecForge/cache/compiled_kernels/4b/c4bs4faiglroap6s5exqskt27kcrendxbin52u3h5veyhjham26s.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1d6f28f2625d95f634c39906b62493b91fd39a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4b/c4bs4faiglroap6s5exqskt27kcrendxbin52u3h5veyhjham26s.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/4b/d5ba6f7d42b5d9435ddcf38a1bb5c3debbf831fb501a3c2b291cd06d573f3c00.best_config b/progress/SpecForge/cache/compiled_kernels/4b/d5ba6f7d42b5d9435ddcf38a1bb5c3debbf831fb501a3c2b291cd06d573f3c00.best_config new file mode 100644 index 0000000000000000000000000000000000000000..77b32a09d39c99cc42575e211ae8fe06cc0fdf56 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4b/d5ba6f7d42b5d9435ddcf38a1bb5c3debbf831fb501a3c2b291cd06d573f3c00.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 41, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4j/c4jmnaggbi4yx2v2ae5kwf3fogxpa66ljtkcafj6ca4cdsnhrruu.py b/progress/SpecForge/cache/compiled_kernels/4j/c4jmnaggbi4yx2v2ae5kwf3fogxpa66ljtkcafj6ca4cdsnhrruu.py new file mode 100644 index 0000000000000000000000000000000000000000..e0808012e71ed317075c59cdc92ce3e950f21882 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4j/c4jmnaggbi4yx2v2ae5kwf3fogxpa66ljtkcafj6ca4cdsnhrruu.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4q/c4q7bstz2lpt7xdxwimrikdhakr453bc2dshde4ipprtpwyfkemh.py b/progress/SpecForge/cache/compiled_kernels/4q/c4q7bstz2lpt7xdxwimrikdhakr453bc2dshde4ipprtpwyfkemh.py new file mode 100644 index 0000000000000000000000000000000000000000..3770a728219f6c34114d5a411145ebaad9bb25fc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4q/c4q7bstz2lpt7xdxwimrikdhakr453bc2dshde4ipprtpwyfkemh.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4q/c4qr34ksfsbq2fspbuwjj2hhj6zbf7jixsrxjyehxywgmpgt3pab.py b/progress/SpecForge/cache/compiled_kernels/4q/c4qr34ksfsbq2fspbuwjj2hhj6zbf7jixsrxjyehxywgmpgt3pab.py new file mode 100644 index 0000000000000000000000000000000000000000..09c193c53b8a7ee0a7abb9de2c8221e290241703 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4q/c4qr34ksfsbq2fspbuwjj2hhj6zbf7jixsrxjyehxywgmpgt3pab.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1488, 1488, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 190464*idx_hq + 6094848*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zj/czjxl4o5zq2vmca3k2ha3vhclo3rwmhvxomcpzhwqmwuamwvvilm.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 380928}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 47616 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1488, 128), (6094848, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_5, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_6, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_7, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_8, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_9, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_10, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_11, (1, 1, 12, 12), (144, 144, 12, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1488, 128), (6094848, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 12, 1, 32, stream=stream0) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, 47616, stream=stream0) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1488, 128), (6094848, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/4w/a6c1f4a6f24078ac5c758b15e1ac8fb70632f5bd1407a64486e4c64c3bcf9a1a.best_config b/progress/SpecForge/cache/compiled_kernels/4w/a6c1f4a6f24078ac5c758b15e1ac8fb70632f5bd1407a64486e4c64c3bcf9a1a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3514cd4b7efbcfec5a9027a69143f1b0c3ed176a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4w/a6c1f4a6f24078ac5c758b15e1ac8fb70632f5bd1407a64486e4c64c3bcf9a1a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/4w/c4wqg7i4qnegwqmf5zl27i2sqjmjqdudlbpz4efrh6osexb7o6ci.py b/progress/SpecForge/cache/compiled_kernels/4w/c4wqg7i4qnegwqmf5zl27i2sqjmjqdudlbpz4efrh6osexb7o6ci.py new file mode 100644 index 0000000000000000000000000000000000000000..7842c7e75e9cf27437cd9b462b51750c87535a71 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/4w/c4wqg7i4qnegwqmf5zl27i2sqjmjqdudlbpz4efrh6osexb7o6ci.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/53/c53a5oei2rsmuun5kcyppqznwsadzdjf2dy3pco6tft3vwznhvqt.py b/progress/SpecForge/cache/compiled_kernels/53/c53a5oei2rsmuun5kcyppqznwsadzdjf2dy3pco6tft3vwznhvqt.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa578d04794d4358609920636ac5fe95e7c076a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/53/c53a5oei2rsmuun5kcyppqznwsadzdjf2dy3pco6tft3vwznhvqt.py @@ -0,0 +1,879 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iy/ciy2s4y23ebwguemfa5mzklgtf7ftiru2o7wglybbpeorpr3fzwa.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wr/cwrxcuyienn7enb5zk5bybqn72awxq2eifgpxezahby6jelebcni.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ni/cnik5jer2sj77e3dwe5apycrphkqz4kvi3x3p5kyfm4zvmvyf7ek.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream6) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream6) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream6 = get_raw_stream(6) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream6) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 96 + arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 96 + arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg4_1 = 96 + arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg7_1 = 96 + arg8_1 = 96 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5c/53d90dec96369229b6103e5de3881d9618e9bf2b99f3390cf4170e375571ebeb.best_config b/progress/SpecForge/cache/compiled_kernels/5c/53d90dec96369229b6103e5de3881d9618e9bf2b99f3390cf4170e375571ebeb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..38d3904992118dfdfceb98c87234684feb5244a1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5c/53d90dec96369229b6103e5de3881d9618e9bf2b99f3390cf4170e375571ebeb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 58, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5c/8e92ef8ab7a77e4fe957d10110613c2723ae2345100e8d8dbef2fac63c71deb5.best_config b/progress/SpecForge/cache/compiled_kernels/5c/8e92ef8ab7a77e4fe957d10110613c2723ae2345100e8d8dbef2fac63c71deb5.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d8de3b1b0a48196e4b4d72effd86a6bb6a703972 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5c/8e92ef8ab7a77e4fe957d10110613c2723ae2345100e8d8dbef2fac63c71deb5.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 27, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5c/c5caldrch44zpiwt4a2buurde2xbt72zjduivrszj2lwkve6vraz.py b/progress/SpecForge/cache/compiled_kernels/5c/c5caldrch44zpiwt4a2buurde2xbt72zjduivrszj2lwkve6vraz.py new file mode 100644 index 0000000000000000000000000000000000000000..1f18dd897ebec282b797283578f7a2fea0781bc4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5c/c5caldrch44zpiwt4a2buurde2xbt72zjduivrszj2lwkve6vraz.py @@ -0,0 +1,707 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/o7/co7e6rsc4n6p3n2to44nik2swoktbbywea2giuocvvog4a76jlfd.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ad/cadgyqqf5fbg25th4xtgb2wm2frap446ysp5yoz4o2aj7csedfvj.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream5) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream5) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5c/c5cbmaklqhtijnznwcqd24bqwgvkyjaacu7sxp5mgjhbog44ug34.py b/progress/SpecForge/cache/compiled_kernels/5c/c5cbmaklqhtijnznwcqd24bqwgvkyjaacu7sxp5mgjhbog44ug34.py new file mode 100644 index 0000000000000000000000000000000000000000..92c2589a435b68afb0baf82e3122a23fc1703ce5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5c/c5cbmaklqhtijnznwcqd24bqwgvkyjaacu7sxp5mgjhbog44ug34.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/5c/c5cxskvlmc4hq63kmkcexukwuhivuede7r7tzidpzo5llwdedeem.py b/progress/SpecForge/cache/compiled_kernels/5c/c5cxskvlmc4hq63kmkcexukwuhivuede7r7tzidpzo5llwdedeem.py new file mode 100644 index 0000000000000000000000000000000000000000..cb614ba0a154572bf79f0d4de34cbf98552513db --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5c/c5cxskvlmc4hq63kmkcexukwuhivuede7r7tzidpzo5llwdedeem.py @@ -0,0 +1,58 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/5k/c5kiw545j4jh6gzwff5a5mpcfuhjlcvxbrhzjxxfwcjggt7l6rju.py b/progress/SpecForge/cache/compiled_kernels/5k/c5kiw545j4jh6gzwff5a5mpcfuhjlcvxbrhzjxxfwcjggt7l6rju.py new file mode 100644 index 0000000000000000000000000000000000000000..74d1a2142008dc33b755e1e79ae955d06fe78a6b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5k/c5kiw545j4jh6gzwff5a5mpcfuhjlcvxbrhzjxxfwcjggt7l6rju.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/5k/c5kt3f6gms6ukt6ximdfynqc5pvqgs55kd7mkdaztzavraguxh5t.py b/progress/SpecForge/cache/compiled_kernels/5k/c5kt3f6gms6ukt6ximdfynqc5pvqgs55kd7mkdaztzavraguxh5t.py new file mode 100644 index 0000000000000000000000000000000000000000..c24ed8f0e9a562fb99c4cf2e152a5246b8599ee4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5k/c5kt3f6gms6ukt6ximdfynqc5pvqgs55kd7mkdaztzavraguxh5t.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5k/c5kz542etdyxkajzodfawxuwoi4ecb4cnwbqdpyxrrzpxo76j36q.py b/progress/SpecForge/cache/compiled_kernels/5k/c5kz542etdyxkajzodfawxuwoi4ecb4cnwbqdpyxrrzpxo76j36q.py new file mode 100644 index 0000000000000000000000000000000000000000..09312925ef2d0a975c485417767685c57f68fcc6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5k/c5kz542etdyxkajzodfawxuwoi4ecb4cnwbqdpyxrrzpxo76j36q.py @@ -0,0 +1,1019 @@ +# AOT ID: ['3_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/n5/cn5ykoaxxefbhtqa6gkemon2itryrhe2jlchrjsmbb5tt7mje2lt.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream6) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 48 + primals_9 = 48 + primals_2 = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 48), (1536, 48, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 48, 128), (196608, 6144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 48), (1536, 48, 1), device='cuda:6', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5k/cab0354633bccbc45a0e9de99913b97ad7f1b34d114d129a16e73e21fe10ca4a.best_config b/progress/SpecForge/cache/compiled_kernels/5k/cab0354633bccbc45a0e9de99913b97ad7f1b34d114d129a16e73e21fe10ca4a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..26d4b338538085b9ff438ba22ca45b3a0321909a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5k/cab0354633bccbc45a0e9de99913b97ad7f1b34d114d129a16e73e21fe10ca4a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 39, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5l/c5lczyjggplcra6zdc2p7rxdglbrz4gfy7g46ygh3is2av7v2zjj.py b/progress/SpecForge/cache/compiled_kernels/5l/c5lczyjggplcra6zdc2p7rxdglbrz4gfy7g46ygh3is2av7v2zjj.py new file mode 100644 index 0000000000000000000000000000000000000000..82e84281ca41e58fa53deb23e61bd0ea5922e361 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5l/c5lczyjggplcra6zdc2p7rxdglbrz4gfy7g46ygh3is2av7v2zjj.py @@ -0,0 +1,1019 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6o/c6ovzyfo6vkdwwzou6dtdvw7qjf65ifmzpcoltl2nx2xuluryjcy.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cx/ccxglpepxy3o4w7avc734pw6iocwuubrgdrssbwdcl4eo656oxqs.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream0) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 128 + primals_9 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:0', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:0', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5l/c5lt2sqadbnwov5ckqtj4urvvaupfdgpyr7axcx3b7zejitkclxc.py b/progress/SpecForge/cache/compiled_kernels/5l/c5lt2sqadbnwov5ckqtj4urvvaupfdgpyr7axcx3b7zejitkclxc.py new file mode 100644 index 0000000000000000000000000000000000000000..4cbd09b754e35b1c4fb0f0dee02a3d57e565281a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5l/c5lt2sqadbnwov5ckqtj4urvvaupfdgpyr7axcx3b7zejitkclxc.py @@ -0,0 +1,879 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/o5/co5ymfefdgahcwkjo4bf6n435ycdi63hj2qw5lci4mp5y23gvenx.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jo/cjoxenwm7p2jwy3xzklvpkqkmrskyjg666yf62k6ubx4klm5uzrc.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6s/c6sdlh6kboie34frvzfvnvbxpzilpis2fninmoaabs2weqiuckab.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:0" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:0" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream0) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream0) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream0 = get_raw_stream(0) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream0) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 80 + arg1_1 = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 80 + arg3_1 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg4_1 = 80 + arg5_1 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg7_1 = 80 + arg8_1 = 80 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5p/c5pp75cuarhldud4qzzysxrwuqnimky5vcuww7c7xxxx4jpzyknc.py b/progress/SpecForge/cache/compiled_kernels/5p/c5pp75cuarhldud4qzzysxrwuqnimky5vcuww7c7xxxx4jpzyknc.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3cb3c7b9cda5a6f1b78b57d718aaeedf6714d1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5p/c5pp75cuarhldud4qzzysxrwuqnimky5vcuww7c7xxxx4jpzyknc.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5q/8fc440c5839bdae8574245d90adee0dbeaad19440dafa6ce1de8e9734228ce4e.best_config b/progress/SpecForge/cache/compiled_kernels/5q/8fc440c5839bdae8574245d90adee0dbeaad19440dafa6ce1de8e9734228ce4e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..07a69f10191cd277391408d5d264adbb97d972d0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5q/8fc440c5839bdae8574245d90adee0dbeaad19440dafa6ce1de8e9734228ce4e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "H4SXLEKVL4YRMR42RH37LPFTP3TUL362YYXQ7Y6Q7FM6YRQBRWBQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5q/c5qrxjagmrzczlwu37wboybmhbvhgf6j6s3vlys5drj5eo6rjpxf.py b/progress/SpecForge/cache/compiled_kernels/5q/c5qrxjagmrzczlwu37wboybmhbvhgf6j6s3vlys5drj5eo6rjpxf.py new file mode 100644 index 0000000000000000000000000000000000000000..fc06dc84dd79cce2cad9be181c6f2e2c6a53b438 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5q/c5qrxjagmrzczlwu37wboybmhbvhgf6j6s3vlys5drj5eo6rjpxf.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 139264}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 17408 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/5q/c5qxm76eg6wfczhqmdc3gfksgjzrbgtma7q4i4gsyz7dq3yw7ikj.py b/progress/SpecForge/cache/compiled_kernels/5q/c5qxm76eg6wfczhqmdc3gfksgjzrbgtma7q4i4gsyz7dq3yw7ikj.py new file mode 100644 index 0000000000000000000000000000000000000000..d88da1273547d9bdc23faa4afb9bb050b73c686c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5q/c5qxm76eg6wfczhqmdc3gfksgjzrbgtma7q4i4gsyz7dq3yw7ikj.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/5q/d350a9256559a6272599e1d59b94895a102236a6355420d9f2e960a0277c7b53.best_config b/progress/SpecForge/cache/compiled_kernels/5q/d350a9256559a6272599e1d59b94895a102236a6355420d9f2e960a0277c7b53.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5q/d350a9256559a6272599e1d59b94895a102236a6355420d9f2e960a0277c7b53.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5t/c5tpnnnaxvhu2fxdhh6qzp3skfsgzrqsi6dxrevhv236xt5n3mx4.py b/progress/SpecForge/cache/compiled_kernels/5t/c5tpnnnaxvhu2fxdhh6qzp3skfsgzrqsi6dxrevhv236xt5n3mx4.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd3a7a35d093af2d9ecbc89a24cd9fed15a2c64 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5t/c5tpnnnaxvhu2fxdhh6qzp3skfsgzrqsi6dxrevhv236xt5n3mx4.py @@ -0,0 +1,1018 @@ +# AOT ID: ['0_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/td/ctdn3hijvlywgmamgcmc3dewg5hrtelyrdorwcy5pxw43gibtgkp.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 1888, 128][7733248, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 1888, 128][7733248, 241664, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 1888][61440, 1920, 1]cuda:6" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1888, 1888, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 724992, 'r0_': 30932992}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 60416 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 1888) + x1 = xindex // 1888 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zp/czpds2beiaxqyyej5g5alncgbiq5nmsvcq4q5xwoilxffogft2aj.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1888, 128][7733248, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1888, 128][1933312, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1888, 128][1933312, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 1888, 128][7733248, 241664, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 1888, 128][7733248, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 1888, 128][1933312, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_4 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_5] +# %primals_8 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1888, 1888, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 7733248, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1933312, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1933312, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 7733248, 241664, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 7733248, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1933312, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1888 + ZKV = 1 + KV_LEN = 1888 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 15 + stride_q_idx_h = 225 + stride_q_idx_n = 15 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 241664*off_hkv + 1933312*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1888 + KV_LEN = 1888 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1888 + KV_LEN = 1888 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1888, 128), (7733248, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1888, 128), (1933312, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1888, 128), (1933312, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_5, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_6, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_7, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_8, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_9, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_10, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_11, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(getitem, (1, 32, 1888, 128), (7733248, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 1888), (60416, 1888, 1)) + assert_size_stride(tangents_1, (1, 32, 1888, 128), (7733248, 241664, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 1888), (60416, 1888, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((1, 32, 1888), (60416, 1888, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 60416, 128, stream=stream6) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 1888, 128), (7733248, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 1888, 128), (1933312, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 1888, 128), (1933312, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_4, primals_5, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 75, 1, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1888, 128), (7733248, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1888, 128), (1933312, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1888, 128), (1933312, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((1, 32, 1888, 128), (7733248, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1888), (60416, 1888, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1888, 128), (7733248, 241664, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1888), (60416, 1888, 1), device='cuda:6', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/5v/c5v437ejb44taeatgdbhhz6gso35w4st5ae2sbfssl6q6bdhqep2.py b/progress/SpecForge/cache/compiled_kernels/5v/c5v437ejb44taeatgdbhhz6gso35w4st5ae2sbfssl6q6bdhqep2.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4b500c9ad2050887fc6461bf87136aa409295f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5v/c5v437ejb44taeatgdbhhz6gso35w4st5ae2sbfssl6q6bdhqep2.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5x/c5xdgqojur4z4zvrtv44kgcalsk3yhzcmhqhexku5vmlehrc6gsj.py b/progress/SpecForge/cache/compiled_kernels/5x/c5xdgqojur4z4zvrtv44kgcalsk3yhzcmhqhexku5vmlehrc6gsj.py new file mode 100644 index 0000000000000000000000000000000000000000..bb37fe96de1a94f1107ca92c02dd20b5cbe2e0bb --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5x/c5xdgqojur4z4zvrtv44kgcalsk3yhzcmhqhexku5vmlehrc6gsj.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/5y/c5yds4p2gittjvexoxho6pxyercb45z3bwytkoph2nrdjazdnaqn.py b/progress/SpecForge/cache/compiled_kernels/5y/c5yds4p2gittjvexoxho6pxyercb45z3bwytkoph2nrdjazdnaqn.py new file mode 100644 index 0000000000000000000000000000000000000000..b919fed04ca6c9fad47569706daa5b8f3d90c167 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/5y/c5yds4p2gittjvexoxho6pxyercb45z3bwytkoph2nrdjazdnaqn.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6a/b124a2f123b86e178901c31f50b7bf7d2ba53382a64e741c640d9d5484db7fca.best_config b/progress/SpecForge/cache/compiled_kernels/6a/b124a2f123b86e178901c31f50b7bf7d2ba53382a64e741c640d9d5484db7fca.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c56b3ff6df726fa6b67725165b8989b16d820629 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6a/b124a2f123b86e178901c31f50b7bf7d2ba53382a64e741c640d9d5484db7fca.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6a/c6ahrgp7r4tqigkmmvnxij5lsn6k4v6kk2srdfijxf2qpmohaxmh.py b/progress/SpecForge/cache/compiled_kernels/6a/c6ahrgp7r4tqigkmmvnxij5lsn6k4v6kk2srdfijxf2qpmohaxmh.py new file mode 100644 index 0000000000000000000000000000000000000000..088f9aaad7be743d04b6633ab4327eda2cde39fd --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6a/c6ahrgp7r4tqigkmmvnxij5lsn6k4v6kk2srdfijxf2qpmohaxmh.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/6d/b74538074e358a3785183cf20cb7eda9cae89196ed84a95ac3059baec9056de5.best_config b/progress/SpecForge/cache/compiled_kernels/6d/b74538074e358a3785183cf20cb7eda9cae89196ed84a95ac3059baec9056de5.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d131a13c7786a0afd9ad24a5628260bf7969ba42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6d/b74538074e358a3785183cf20cb7eda9cae89196ed84a95ac3059baec9056de5.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6d/c6daw7m3ggtdhg3z7av243dbutkzpxj2x6joynfpyrx23uue5ta4.py b/progress/SpecForge/cache/compiled_kernels/6d/c6daw7m3ggtdhg3z7av243dbutkzpxj2x6joynfpyrx23uue5ta4.py new file mode 100644 index 0000000000000000000000000000000000000000..22386a0ee240d0d3470db150fbde2999c9e0c3ea --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6d/c6daw7m3ggtdhg3z7av243dbutkzpxj2x6joynfpyrx23uue5ta4.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1488, 1488, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 190464*idx_hq + 6094848*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ot/cotrsbintrwogfyqqgnphf2yeyr6yfylbjbvpw7omv42gwfsf4bc.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 380928}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 47616 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1488, 128), (6094848, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 12), (12, 12, 1)) + assert_size_stride(arg4_1, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(arg5_1, (1, 1, 12), (12, 12, 1)) + assert_size_stride(arg6_1, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(arg7_1, (1, 1, 12), (12, 12, 1)) + assert_size_stride(arg8_1, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(arg9_1, (1, 1, 12), (12, 12, 1)) + assert_size_stride(arg10_1, (1, 1, 12, 12), (144, 144, 12, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1488, 128), (6094848, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 12, 1, 32, stream=stream0) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, 47616, stream=stream0) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1488, 128), (6094848, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/6d/c6dg7vz4gdbqxi3w2qjuidwukwzat7wmle6l4qqt6azqfox3i34w.py b/progress/SpecForge/cache/compiled_kernels/6d/c6dg7vz4gdbqxi3w2qjuidwukwzat7wmle6l4qqt6azqfox3i34w.py new file mode 100644 index 0000000000000000000000000000000000000000..651622052542f0445aa01c761b0b045a86d62303 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6d/c6dg7vz4gdbqxi3w2qjuidwukwzat7wmle6l4qqt6azqfox3i34w.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/6e/c6ehlcn5h26qvvbvcwcgnd7qcdu3kr3o2cmulaou4x6ehbkjets2.py b/progress/SpecForge/cache/compiled_kernels/6e/c6ehlcn5h26qvvbvcwcgnd7qcdu3kr3o2cmulaou4x6ehbkjets2.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc79f854c569831c321764d115cb871364a94ca --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6e/c6ehlcn5h26qvvbvcwcgnd7qcdu3kr3o2cmulaou4x6ehbkjets2.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sw/cswwaepmaupqvarnmz6ncii4cmhronmafzughe6w7yjkeroyh33w.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (768, 768, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 98304*idx_hq + 3145728*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jc/cjcydfa56brqsssfuw6ny7n53hdzuh5cl4i2gpdzrzz6k6leiidf.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg4_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg5_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg6_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg7_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg8_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg9_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg10_1, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 6, 1, 32, stream=stream1) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, 24576, stream=stream1) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/6f/c6ftnziemwjiwwsiow7vtoa3h64pd4e6ywxbwrobq77h5wamyg7z.py b/progress/SpecForge/cache/compiled_kernels/6f/c6ftnziemwjiwwsiow7vtoa3h64pd4e6ywxbwrobq77h5wamyg7z.py new file mode 100644 index 0000000000000000000000000000000000000000..cf786b3b1acd4f59ce46d37b93b105a67361c0a5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6f/c6ftnziemwjiwwsiow7vtoa3h64pd4e6ywxbwrobq77h5wamyg7z.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/6h/c6hkxfifxs22so24ve6wkgcqqzi764ulshiucid5fo62ujddp2hy.py b/progress/SpecForge/cache/compiled_kernels/6h/c6hkxfifxs22so24ve6wkgcqqzi764ulshiucid5fo62ujddp2hy.py new file mode 100644 index 0000000000000000000000000000000000000000..47109d48aee021222e0bd159425e51a320cb928b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6h/c6hkxfifxs22so24ve6wkgcqqzi764ulshiucid5fo62ujddp2hy.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 786432, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3145728, 98304, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3145728, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 98304*off_hkv + 786432*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6k/c6kczdbfbgpig7rgesgmdffdqzeulcdjlny2ga3a4a3m2qv4646c.py b/progress/SpecForge/cache/compiled_kernels/6k/c6kczdbfbgpig7rgesgmdffdqzeulcdjlny2ga3a4a3m2qv4646c.py new file mode 100644 index 0000000000000000000000000000000000000000..07514a940e3c805ce7a7f723bb29852aabd4dca1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6k/c6kczdbfbgpig7rgesgmdffdqzeulcdjlny2ga3a4a3m2qv4646c.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7733248, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1933312, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1933312, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1888 + ZKV = 1 + KV_LEN = 1888 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 241664*idx_hq + 7733248*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py b/progress/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py new file mode 100644 index 0000000000000000000000000000000000000000..a55df20e7a37de5560bd0e7c500648aae32af90f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6n/c6n4rf57opno6rcuedu4jk4etcok4ti2tlaztx2ht3z5eydc3vae.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6q/c6qm5zk477k2f67zwfk64l2zc5z2lv6cndjmdk6uj7oobikrxtnf.py b/progress/SpecForge/cache/compiled_kernels/6q/c6qm5zk477k2f67zwfk64l2zc5z2lv6cndjmdk6uj7oobikrxtnf.py new file mode 100644 index 0000000000000000000000000000000000000000..9c73247e633083b298ae8f0ae83a675aa6b63c74 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6q/c6qm5zk477k2f67zwfk64l2zc5z2lv6cndjmdk6uj7oobikrxtnf.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1753088, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 7012352, 219136, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 7012352, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 14 + stride_q_idx_h = 196 + stride_q_idx_n = 14 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 219136*off_hkv + 1753088*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1712 + KV_LEN = 1712 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1712 + KV_LEN = 1712 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/6s/c6sdlh6kboie34frvzfvnvbxpzilpis2fninmoaabs2weqiuckab.py b/progress/SpecForge/cache/compiled_kernels/6s/c6sdlh6kboie34frvzfvnvbxpzilpis2fninmoaabs2weqiuckab.py new file mode 100644 index 0000000000000000000000000000000000000000..2659a312854445ec03538b65647f7394a60cc50e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6s/c6sdlh6kboie34frvzfvnvbxpzilpis2fninmoaabs2weqiuckab.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/6s/c6sslr5g72q7hg34kzid7vs25mj6d7gux5vghpnt7idohpt4rijw.py b/progress/SpecForge/cache/compiled_kernels/6s/c6sslr5g72q7hg34kzid7vs25mj6d7gux5vghpnt7idohpt4rijw.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdf95dee764ecb681966c0bfa64d84feb294e62 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6s/c6sslr5g72q7hg34kzid7vs25mj6d7gux5vghpnt7idohpt4rijw.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (976, 976, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 31232 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_5, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_6, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_7, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_8, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_9, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_10, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_11, (1, 1, 8, 8), (64, 64, 8, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 8, 1, 32, stream=stream7) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, 31232, stream=stream7) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/6s/c6ssul5nzmvqrrcjrod23dfuvdvfjvoy55obxutxwkiy5tgfme6h.py b/progress/SpecForge/cache/compiled_kernels/6s/c6ssul5nzmvqrrcjrod23dfuvdvfjvoy55obxutxwkiy5tgfme6h.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4107f80afd224fcc5d1f43b51eb814719ccb5f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6s/c6ssul5nzmvqrrcjrod23dfuvdvfjvoy55obxutxwkiy5tgfme6h.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6s/d5b96d3a5734461d6cc3d02833f56d2a04bf2e9984c2866619e1f8d5b559f3aa.best_config b/progress/SpecForge/cache/compiled_kernels/6s/d5b96d3a5734461d6cc3d02833f56d2a04bf2e9984c2866619e1f8d5b559f3aa.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5b5ab40b4f49a1b15e3921b106c5a2983d0d4c0f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6s/d5b96d3a5734461d6cc3d02833f56d2a04bf2e9984c2866619e1f8d5b559f3aa.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 91, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/6x/c6xjqivrsy22ckftcnn7eivmrljsf464pq24hrvtrwms2fv6gzff.py b/progress/SpecForge/cache/compiled_kernels/6x/c6xjqivrsy22ckftcnn7eivmrljsf464pq24hrvtrwms2fv6gzff.py new file mode 100644 index 0000000000000000000000000000000000000000..c8121470b895cad727e42383b998ba937756c37f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/6x/c6xjqivrsy22ckftcnn7eivmrljsf464pq24hrvtrwms2fv6gzff.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/by/cbyac7ujf4m27xfhgoydy3wivjfo7hn2stpogjvd2psmfe3ett4g.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1136, 128][4653056, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1136, 128][1163264, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1136, 128][1163264, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1136][36352, 1136, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1136][36352, 1136, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:0" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1136, 1136, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4653056, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1163264, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1163264, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1136 + ZKV = 1 + KV_LEN = 1136 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 145408*idx_hq + 4653056*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/e7/ce7afyeec4c2ygisnhuxors73dv5exizcb53mj2jnatptwiaragm.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1136][36352, 1136, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 290816}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 36352 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1136, 128), (4653056, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1136, 128), (1163264, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1136, 128), (1163264, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg4_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg5_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg6_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg7_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg8_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg9_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg10_1, (1, 1, 9, 9), (81, 81, 9, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, 1136), (36352, 1136, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1136), (36352, 1136, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1136, 128), (4653056, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 9, 1, 32, stream=stream0) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, 36352, stream=stream0) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1136, 128), (4653056, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1136, 128), (1163264, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1136, 128), (1163264, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/75/c754kjoskxv3cqqkmnl5v5i6bmsuwcudcec4lf36webo5fkp64ty.py b/progress/SpecForge/cache/compiled_kernels/75/c754kjoskxv3cqqkmnl5v5i6bmsuwcudcec4lf36webo5fkp64ty.py new file mode 100644 index 0000000000000000000000000000000000000000..70b0b890e5eb8b9cb9b4f6340718cd46f5e484e5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/75/c754kjoskxv3cqqkmnl5v5i6bmsuwcudcec4lf36webo5fkp64ty.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/77/c77answdbsc3cyb3vrut4dwz2fsgr7k7rimvme63nesuaeqlvtxn.py b/progress/SpecForge/cache/compiled_kernels/77/c77answdbsc3cyb3vrut4dwz2fsgr7k7rimvme63nesuaeqlvtxn.py new file mode 100644 index 0000000000000000000000000000000000000000..2af6633594b5fdc19bee75665143b7389ea96828 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/77/c77answdbsc3cyb3vrut4dwz2fsgr7k7rimvme63nesuaeqlvtxn.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/77/c77c3fipfccx76grkmxhxkyw3zcl2ffw4gk4x6bbbxnlnqcvc266.py b/progress/SpecForge/cache/compiled_kernels/77/c77c3fipfccx76grkmxhxkyw3zcl2ffw4gk4x6bbbxnlnqcvc266.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6ae1a50520caecb2beb93b4820028c8387ae2c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/77/c77c3fipfccx76grkmxhxkyw3zcl2ffw4gk4x6bbbxnlnqcvc266.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/7c/c7crmnpas2tenly5xfqty5ujn2qgug5vc2t7otwimn373gxrbtcj.py b/progress/SpecForge/cache/compiled_kernels/7c/c7crmnpas2tenly5xfqty5ujn2qgug5vc2t7otwimn373gxrbtcj.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0e765cf39156649442bc00a9818e6a9e73f224 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7c/c7crmnpas2tenly5xfqty5ujn2qgug5vc2t7otwimn373gxrbtcj.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6f/c6ftnziemwjiwwsiow7vtoa3h64pd4e6ywxbwrobq77h5wamyg7z.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:7" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:7" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sz/cszgqakbiplu4kezgnolrvjqw5rvgv44pau44u5c3nelyopoju4t.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream7) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream7) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1184 + arg1_1 = rand_strided((1, 32, 1184, 128), (4849664, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = 1184 + arg3_1 = rand_strided((1, 8, 1184, 128), (1212416, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg4_1 = 1184 + arg5_1 = rand_strided((1, 8, 1184, 128), (1212416, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg6_1 = 10 + arg7_1 = 10 + arg8_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:7', dtype=torch.int32) + arg9_1 = 1184 + arg10_1 = 1184 + arg11_1 = 10 + arg12_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:7', dtype=torch.int32) + arg13_1 = 10 + arg14_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:7', dtype=torch.int32) + arg15_1 = 10 + arg16_1 = 10 + arg17_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:7', dtype=torch.int32) + arg18_1 = 10 + arg19_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:7', dtype=torch.int32) + arg20_1 = 10 + arg21_1 = 10 + arg22_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:7', dtype=torch.int32) + arg23_1 = 10 + arg24_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:7', dtype=torch.int32) + arg25_1 = 10 + arg26_1 = 10 + arg27_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/7e/95e6a65c1680c61cadcb6b1de6a9bcc3de942edd854204682507d1139982bad3.best_config b/progress/SpecForge/cache/compiled_kernels/7e/95e6a65c1680c61cadcb6b1de6a9bcc3de942edd854204682507d1139982bad3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..21da4474563d7d614bfb841e87c69a5bbfb4251e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7e/95e6a65c1680c61cadcb6b1de6a9bcc3de942edd854204682507d1139982bad3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "O7WDT227PLBDA4RTCWQHRD6TX327YREMCF75BHIDX3W5YOJNWYSQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/7e/c7ennval4wcksoe7vbsaqyaxfyp6e7sxfao65gq43d6h3sl74k2b.py b/progress/SpecForge/cache/compiled_kernels/7e/c7ennval4wcksoe7vbsaqyaxfyp6e7sxfao65gq43d6h3sl74k2b.py new file mode 100644 index 0000000000000000000000000000000000000000..bba3e8bffff532f50886b1c057b65f2ce45ae66c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7e/c7ennval4wcksoe7vbsaqyaxfyp6e7sxfao65gq43d6h3sl74k2b.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py b/progress/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a09dd0b6e0fb4d66827748f76cf479c2779e4a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py b/progress/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py new file mode 100644 index 0000000000000000000000000000000000000000..fec96a75ed4445b9feabe78f02bc85f9ffc72ca7 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7g/c7gxkvfztxetv7w7i4s7mr7dlsdda3dfgq3f3uijvhozq6ggk4o4.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/7j/c7jy7orghqlzyryav6xs6tfj7jilalobanvnpyxyky7yueuhrwlx.py b/progress/SpecForge/cache/compiled_kernels/7j/c7jy7orghqlzyryav6xs6tfj7jilalobanvnpyxyky7yueuhrwlx.py new file mode 100644 index 0000000000000000000000000000000000000000..e753d9943a9f6c9a55465dd9debc050bc22e273b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7j/c7jy7orghqlzyryav6xs6tfj7jilalobanvnpyxyky7yueuhrwlx.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/7m/c7mshqtmfwavmrqyupwyrvidkebn2hsvezg5kt2eukjeqlblnwaf.py b/progress/SpecForge/cache/compiled_kernels/7m/c7mshqtmfwavmrqyupwyrvidkebn2hsvezg5kt2eukjeqlblnwaf.py new file mode 100644 index 0000000000000000000000000000000000000000..54b35ae6d7f251758f4cc32cb5d66bec8a072b78 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7m/c7mshqtmfwavmrqyupwyrvidkebn2hsvezg5kt2eukjeqlblnwaf.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/77/c77answdbsc3cyb3vrut4dwz2fsgr7k7rimvme63nesuaeqlvtxn.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:2" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:2" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/fb/cfbkculzmi75soceyqvz2mtn3lc6czlkb7vqdveqx3cvclvpgynk.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream2) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1040 + arg1_1 = rand_strided((1, 32, 1040, 128), (4259840, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = 1040 + arg3_1 = rand_strided((1, 8, 1040, 128), (1064960, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg4_1 = 1040 + arg5_1 = rand_strided((1, 8, 1040, 128), (1064960, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg6_1 = 9 + arg7_1 = 9 + arg8_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:2', dtype=torch.int32) + arg9_1 = 1040 + arg10_1 = 1040 + arg11_1 = 9 + arg12_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:2', dtype=torch.int32) + arg13_1 = 9 + arg14_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:2', dtype=torch.int32) + arg15_1 = 9 + arg16_1 = 9 + arg17_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:2', dtype=torch.int32) + arg18_1 = 9 + arg19_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:2', dtype=torch.int32) + arg20_1 = 9 + arg21_1 = 9 + arg22_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:2', dtype=torch.int32) + arg23_1 = 9 + arg24_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:2', dtype=torch.int32) + arg25_1 = 9 + arg26_1 = 9 + arg27_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/7s/7125d78dc06caaa38adfda2c7c4bfbae40ebb0515665cc2f5a2b4d525b2be500.best_config b/progress/SpecForge/cache/compiled_kernels/7s/7125d78dc06caaa38adfda2c7c4bfbae40ebb0515665cc2f5a2b4d525b2be500.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cd4183848a025b6b12756cf4012355bb745b489f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7s/7125d78dc06caaa38adfda2c7c4bfbae40ebb0515665cc2f5a2b4d525b2be500.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "R0_BLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 39, "triton_cache_hash": "CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/7s/c7s3gruxs5ilte4q2jq52novut376etxvj3wtbiy4hnaadtemqgt.py b/progress/SpecForge/cache/compiled_kernels/7s/c7s3gruxs5ilte4q2jq52novut376etxvj3wtbiy4hnaadtemqgt.py new file mode 100644 index 0000000000000000000000000000000000000000..a7448b1ff04369c8e185afb5a10fabfe6aeea1b3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7s/c7s3gruxs5ilte4q2jq52novut376etxvj3wtbiy4hnaadtemqgt.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/7w/c7wuvfpczzsfkwsy5ra2l5uzjhob3kfxg67qg3lxntlztgutuwnk.py b/progress/SpecForge/cache/compiled_kernels/7w/c7wuvfpczzsfkwsy5ra2l5uzjhob3kfxg67qg3lxntlztgutuwnk.py new file mode 100644 index 0000000000000000000000000000000000000000..75d80b7eb8279ee5a50683e493202371116cafc4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/7w/c7wuvfpczzsfkwsy5ra2l5uzjhob3kfxg67qg3lxntlztgutuwnk.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/a6/ca6qxhttx7tv3rczw5q4gxtd7weetog4sqll327yn6c2mk3ibsvk.py b/progress/SpecForge/cache/compiled_kernels/a6/ca6qxhttx7tv3rczw5q4gxtd7weetog4sqll327yn6c2mk3ibsvk.py new file mode 100644 index 0000000000000000000000000000000000000000..2285053eb3475cf27b9a631fbc03121f53c4ea5f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/a6/ca6qxhttx7tv3rczw5q4gxtd7weetog4sqll327yn6c2mk3ibsvk.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ug/cugsdz6ngnkagg3e7qtsoxgk3tpt6tu3s7opnd7ctcm5xsdh4xit.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1216, 128][4980736, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1216, 128][1245184, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1216, 128][1245184, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1216][38912, 1216, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1216][38912, 1216, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:6" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:6" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1216, 1216, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4980736, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1245184, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1245184, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1216 + ZKV = 1 + KV_LEN = 1216 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 10 + stride_kv_idx_h = 100 + stride_kv_idx_m = 10 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 155648*idx_hq + 4980736*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yr/cyrolr3isgcukbhhfvhukkk5rnexr7yvexiak2u5gv4at6vg32wy.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1216][38912, 1216, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 311296}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 38912 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1216, 128), (4980736, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1216, 128), (1245184, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1216, 128), (1245184, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 10), (10, 10, 1)) + assert_size_stride(arg4_1, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(arg5_1, (1, 1, 10), (10, 10, 1)) + assert_size_stride(arg6_1, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(arg7_1, (1, 1, 10), (10, 10, 1)) + assert_size_stride(arg8_1, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(arg9_1, (1, 1, 10), (10, 10, 1)) + assert_size_stride(arg10_1, (1, 1, 10, 10), (100, 100, 10, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 1216), (38912, 1216, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1216), (38912, 1216, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1216, 128), (4980736, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 10, 1, 32, stream=stream6) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, 38912, stream=stream6) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1216, 128), (4980736, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1216, 128), (1245184, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1216, 128), (1245184, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:6', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:6', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:6', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:6', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:6', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:6', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:6', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/a7/ca7s7tp233627gzszjumpewmgfr3x27cno6rvah7rvjcudtqof37.py b/progress/SpecForge/cache/compiled_kernels/a7/ca7s7tp233627gzszjumpewmgfr3x27cno6rvah7rvjcudtqof37.py new file mode 100644 index 0000000000000000000000000000000000000000..eeee99ba470147cca7bc7f3bddd870c7ac344fc0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/a7/ca7s7tp233627gzszjumpewmgfr3x27cno6rvah7rvjcudtqof37.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ab/cabnd7sovix3lehr7g2wcvtfl5qeb3jvpm2blacba6y46ltkger7.py b/progress/SpecForge/cache/compiled_kernels/ab/cabnd7sovix3lehr7g2wcvtfl5qeb3jvpm2blacba6y46ltkger7.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d471ed9577fa53e44ab23dfb9c920435e19a52 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ab/cabnd7sovix3lehr7g2wcvtfl5qeb3jvpm2blacba6y46ltkger7.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (624, 624, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 79872*idx_hq + 2555904*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4a/c4anvofd7zpauvjftktjqqrtqgoz7sd5ckg3wf7op3c7cmnbrzz6.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 159744}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 19968 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_5, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_6, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_7, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_8, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_9, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_10, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_11, (1, 1, 5, 5), (25, 25, 5, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 624, 128), (2555904, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 5, 1, 32, stream=stream4) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, 19968, stream=stream4) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ae/caexsdsjli6bhdond3z2qboeu32qkhdr63ftnngmx6jiska5n4cu.py b/progress/SpecForge/cache/compiled_kernels/ae/caexsdsjli6bhdond3z2qboeu32qkhdr63ftnngmx6jiska5n4cu.py new file mode 100644 index 0000000000000000000000000000000000000000..762a4b22599a91de541fc535865c0cd187a10755 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ae/caexsdsjli6bhdond3z2qboeu32qkhdr63ftnngmx6jiska5n4cu.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 999424, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3997696, 124928, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3997696, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 8 + stride_q_idx_h = 64 + stride_q_idx_n = 8 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 124928*off_hkv + 999424*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/am/65f16fd6e06280ad5650406f006f1efe93b73e010e8ef7f2da4fad62c30e989f.best_config b/progress/SpecForge/cache/compiled_kernels/am/65f16fd6e06280ad5650406f006f1efe93b73e010e8ef7f2da4fad62c30e989f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3af6a58642dad18588f3ba5f3c395527e947dca2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/am/65f16fd6e06280ad5650406f006f1efe93b73e010e8ef7f2da4fad62c30e989f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "I32ECYH5I75REV4TRBMWHSFZKPTUF2XC4JIHGP5FR3TBMM6U2PKQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/am/camey3wf35nw3zclfvahr3a2qldzaw7bu3nbioja6kav4yh62xeg.py b/progress/SpecForge/cache/compiled_kernels/am/camey3wf35nw3zclfvahr3a2qldzaw7bu3nbioja6kav4yh62xeg.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9be78d38bbd19ad6836429f5a34b7b543e3fe6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/am/camey3wf35nw3zclfvahr3a2qldzaw7bu3nbioja6kav4yh62xeg.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/progress/SpecForge/cache/compiled_kernels/am/camqma4hqvmxziggaorqeiiobjciqtakh45i2fu3p47njnbhow24.py b/progress/SpecForge/cache/compiled_kernels/am/camqma4hqvmxziggaorqeiiobjciqtakh45i2fu3p47njnbhow24.py new file mode 100644 index 0000000000000000000000000000000000000000..bcab089fad8c2918ea251172262c0f861143e6ce --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/am/camqma4hqvmxziggaorqeiiobjciqtakh45i2fu3p47njnbhow24.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7471104, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1867776, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1867776, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1824 + ZKV = 1 + KV_LEN = 1824 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 233472*idx_hq + 7471104*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/b4/cb4ea7z2vtcaiezoo3wh4zpf75ejd34mkjc6l4u4ckm5daq4tnnx.py b/progress/SpecForge/cache/compiled_kernels/b4/cb4ea7z2vtcaiezoo3wh4zpf75ejd34mkjc6l4u4ckm5daq4tnnx.py new file mode 100644 index 0000000000000000000000000000000000000000..af6ea65b22168c72e0aa9f256969b7c35b3fdfe2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/b4/cb4ea7z2vtcaiezoo3wh4zpf75ejd34mkjc6l4u4ckm5daq4tnnx.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/ba/cbaninj3sj2mvebp5i3voyltfrg6ax5a2t3s36hrhl6klrlwbhlr.py b/progress/SpecForge/cache/compiled_kernels/ba/cbaninj3sj2mvebp5i3voyltfrg6ax5a2t3s36hrhl6klrlwbhlr.py new file mode 100644 index 0000000000000000000000000000000000000000..941d3b1e7307546267397a1d75092b91a087357d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ba/cbaninj3sj2mvebp5i3voyltfrg6ax5a2t3s36hrhl6klrlwbhlr.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/bd/cbdcbcvihdbcyzwk7psebztx3n4mfgcbdwll5wm3ymeikhxitly7.py b/progress/SpecForge/cache/compiled_kernels/bd/cbdcbcvihdbcyzwk7psebztx3n4mfgcbdwll5wm3ymeikhxitly7.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d9798d530695e6c2ddc97ce10fba6b99c18818 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bd/cbdcbcvihdbcyzwk7psebztx3n4mfgcbdwll5wm3ymeikhxitly7.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/bg/cbgyh4dey24b2zp5culpfkeij4cgoe272kimbka3zotzx54g2ful.py b/progress/SpecForge/cache/compiled_kernels/bg/cbgyh4dey24b2zp5culpfkeij4cgoe272kimbka3zotzx54g2ful.py new file mode 100644 index 0000000000000000000000000000000000000000..6be170e1d0a2070716bf2ed08ce13beb022a0280 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bg/cbgyh4dey24b2zp5culpfkeij4cgoe272kimbka3zotzx54g2ful.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/bj/2ce59e77fcf81a124691f1d9b9703e0eb1cf4551dc5cb462f7a7dbb610724e46.best_config b/progress/SpecForge/cache/compiled_kernels/bj/2ce59e77fcf81a124691f1d9b9703e0eb1cf4551dc5cb462f7a7dbb610724e46.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5f6e79cb81182b6b047cbd9ad368e29007e8e196 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bj/2ce59e77fcf81a124691f1d9b9703e0eb1cf4551dc5cb462f7a7dbb610724e46.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "6AG6OALPHM4BFNBPPAU5PR6LOKJOFEIPOLCAMRKFQJZAZPVGODAA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/bj/cbjrt5iaqfrwd545ebmz2cjez2n7o26733qo7khjrhewzsdlew7f.py b/progress/SpecForge/cache/compiled_kernels/bj/cbjrt5iaqfrwd545ebmz2cjez2n7o26733qo7khjrhewzsdlew7f.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e32af5b4f1316fcc36992923ace1affa225b3e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bj/cbjrt5iaqfrwd545ebmz2cjez2n7o26733qo7khjrhewzsdlew7f.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 438272}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 54784 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py b/progress/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py new file mode 100644 index 0000000000000000000000000000000000000000..308d6e7e97e898ac18482834e19a570bc3d1e329 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/br/cda41eee6d5b9cfbff4388fd5323be3ab76c26b52235a1b4a3cbc1d8f6a48049.best_config b/progress/SpecForge/cache/compiled_kernels/br/cda41eee6d5b9cfbff4388fd5323be3ab76c26b52235a1b4a3cbc1d8f6a48049.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/br/cda41eee6d5b9cfbff4388fd5323be3ab76c26b52235a1b4a3cbc1d8f6a48049.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/bu/cbuwktd56uoudprkc2hbpxswoefbyirygxlw77qwsyjspsxv53i7.py b/progress/SpecForge/cache/compiled_kernels/bu/cbuwktd56uoudprkc2hbpxswoefbyirygxlw77qwsyjspsxv53i7.py new file mode 100644 index 0000000000000000000000000000000000000000..4348fcde86e9ec192b1148ca3b47c5e8266d62b9 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bu/cbuwktd56uoudprkc2hbpxswoefbyirygxlw77qwsyjspsxv53i7.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/vj/cvj4vgdwmfi5ohh7mn6zi4ttb7xw7ozyhs3to7ugrnlgkitqtkp2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:6" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:6" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:6" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:6" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yo/cyovrkn5g6t5dss5uyy7ufwyguhqgajhrqatxilbhyxj6rlek4ae.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream6) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream6) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 208 + arg1_1 = rand_strided((1, 32, 208, 128), (851968, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 208 + arg3_1 = rand_strided((1, 8, 208, 128), (212992, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg4_1 = 208 + arg5_1 = rand_strided((1, 8, 208, 128), (212992, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg6_1 = 2 + arg7_1 = 2 + arg8_1 = rand_strided((1, 1, 2, 2), (4, 4, 2, 1), device='cuda:6', dtype=torch.int32) + arg9_1 = 208 + arg10_1 = 208 + arg11_1 = 2 + arg12_1 = rand_strided((1, 1, 2), (2, 2, 1), device='cuda:6', dtype=torch.int32) + arg13_1 = 2 + arg14_1 = rand_strided((1, 1, 2), (2, 2, 1), device='cuda:6', dtype=torch.int32) + arg15_1 = 2 + arg16_1 = 2 + arg17_1 = rand_strided((1, 1, 2, 2), (4, 4, 2, 1), device='cuda:6', dtype=torch.int32) + arg18_1 = 2 + arg19_1 = rand_strided((1, 1, 2), (2, 2, 1), device='cuda:6', dtype=torch.int32) + arg20_1 = 2 + arg21_1 = 2 + arg22_1 = rand_strided((1, 1, 2, 2), (4, 4, 2, 1), device='cuda:6', dtype=torch.int32) + arg23_1 = 2 + arg24_1 = rand_strided((1, 1, 2), (2, 2, 1), device='cuda:6', dtype=torch.int32) + arg25_1 = 2 + arg26_1 = 2 + arg27_1 = rand_strided((1, 1, 2, 2), (4, 4, 2, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/by/cbyac7ujf4m27xfhgoydy3wivjfo7hn2stpogjvd2psmfe3ett4g.py b/progress/SpecForge/cache/compiled_kernels/by/cbyac7ujf4m27xfhgoydy3wivjfo7hn2stpogjvd2psmfe3ett4g.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab68bf518063c26722b645976b4f7d16b1f1915 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/by/cbyac7ujf4m27xfhgoydy3wivjfo7hn2stpogjvd2psmfe3ett4g.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4653056, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1163264, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1163264, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1136 + ZKV = 1 + KV_LEN = 1136 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 145408*idx_hq + 4653056*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/bz/cbzeeqfrpfc7y67gbsotpuhmp4krs376j7wwsxnjo6hqo2gbccst.py b/progress/SpecForge/cache/compiled_kernels/bz/cbzeeqfrpfc7y67gbsotpuhmp4krs376j7wwsxnjo6hqo2gbccst.py new file mode 100644 index 0000000000000000000000000000000000000000..44d0d15baf71b3e04b4e11b84df4905fb7ff9b66 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/bz/cbzeeqfrpfc7y67gbsotpuhmp4krs376j7wwsxnjo6hqo2gbccst.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (976, 976, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6u/c6ulsdn73forgosxqs5bes2cerczsehypg7jodd4snit3gcqp6el.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 249856}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 31232 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_5, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_6, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_7, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_8, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_9, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_10, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_11, (1, 1, 8, 8), (64, 64, 8, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 8, 1, 32, stream=stream7) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, 31232, stream=stream7) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/c4/cc4o3l3d5j7dvyvftvtvcwbesyhq2pwzeojoajmyx6dghtl4owxw.py b/progress/SpecForge/cache/compiled_kernels/c4/cc4o3l3d5j7dvyvftvtvcwbesyhq2pwzeojoajmyx6dghtl4owxw.py new file mode 100644 index 0000000000000000000000000000000000000000..5decbb611681ce8e534620a69420783b083f4830 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/c4/cc4o3l3d5j7dvyvftvtvcwbesyhq2pwzeojoajmyx6dghtl4owxw.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/e5/ce5kwq2j4bjbiesoburv4zkppdydwqwywgfm4bp2c2m2sko7viby.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1888, 128][7733248, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1888, 128][1933312, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1888, 128][1933312, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:6" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:6" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1888, 1888, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7733248, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1933312, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1933312, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1888 + ZKV = 1 + KV_LEN = 1888 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 241664*idx_hq + 7733248*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/y5/cy5jljxgujndttzwbv4ilh6rxm6speqtj427ms2af3tgcjqi4rts.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1888][60416, 1888, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 483328}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 60416 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1888, 128), (7733248, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1888, 128), (1933312, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1888, 128), (1933312, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_5, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_6, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_7, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_8, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_9, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(primals_10, (1, 1, 15), (15, 15, 1)) + assert_size_stride(primals_11, (1, 1, 15, 15), (225, 225, 15, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 1888), (60416, 1888, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1888), (60416, 1888, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1888, 128), (7733248, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 15, 1, 32, stream=stream6) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, 60416, stream=stream6) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1888, 128), (7733248, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1888, 128), (1933312, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1888, 128), (1933312, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/c5/cc5oq7jygergpvto6ee2gvnf2mra6awo63pcxqudcd7ytnj7zodx.py b/progress/SpecForge/cache/compiled_kernels/c5/cc5oq7jygergpvto6ee2gvnf2mra6awo63pcxqudcd7ytnj7zodx.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9ec894efc1656aca4c938ca5c0aba466f4697f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/c5/cc5oq7jygergpvto6ee2gvnf2mra6awo63pcxqudcd7ytnj7zodx.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/c6/cc6f45dc34xsr4g55vytvbpll7g5ovdwld4utnip4s56dzd4nuzw.py b/progress/SpecForge/cache/compiled_kernels/c6/cc6f45dc34xsr4g55vytvbpll7g5ovdwld4utnip4s56dzd4nuzw.py new file mode 100644 index 0000000000000000000000000000000000000000..3041ad6fdc3208d7393a2ea8093d8b10e5acee62 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/c6/cc6f45dc34xsr4g55vytvbpll7g5ovdwld4utnip4s56dzd4nuzw.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ce/ccevljqs5kdl63nuzs6zaycgtfvrqcuzcaf6qedxuj4wuf6pxoso.py b/progress/SpecForge/cache/compiled_kernels/ce/ccevljqs5kdl63nuzs6zaycgtfvrqcuzcaf6qedxuj4wuf6pxoso.py new file mode 100644 index 0000000000000000000000000000000000000000..fb703ccb7f9d9ae9180d88815b42365a627c7f8b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ce/ccevljqs5kdl63nuzs6zaycgtfvrqcuzcaf6qedxuj4wuf6pxoso.py @@ -0,0 +1,1018 @@ +# AOT ID: ['0_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zl/czl7xvjffxknhj474g2xpnuucujqtdridm4gjtkxclxkcxeajpzr.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 624, 128][2555904, 79872, 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (624, 624, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 239616, 'r0_': 10223616}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 19968 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 624) + x1 = xindex // 624 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ry/crylsfwy4n57m7yd6unhb67v2ldprgpyuhxft3lhc6xp2xkvgns5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 624, 128][2555904, 79872, 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_4 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_8 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (624, 624, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 638976, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 2555904, 79872, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 2555904, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 5 + stride_q_idx_h = 25 + stride_q_idx_n = 5 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 79872*off_hkv + 638976*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_5, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_6, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_7, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_8, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_9, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_10, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_11, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(getitem, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 624), (19968, 624, 1)) + assert_size_stride(tangents_1, (1, 32, 624, 128), (2555904, 79872, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 624), (19968, 624, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf1 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 19968, 128, stream=stream4) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 624, 128), (2555904, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 624, 128), (638976, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 624, 128), (638976, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_4, primals_5, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 25, 1, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 624), (19968, 624, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 624, 128), (2555904, 79872, 128, 1), device='cuda:4', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 624), (19968, 624, 1), device='cuda:4', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py b/progress/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py new file mode 100644 index 0000000000000000000000000000000000000000..29c260d35700d86ec7fd8b426a5b9ffa2a9e6e34 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 190464*idx_hq + 6094848*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/cj/c43bed5f9e75971f3eee0a5e390e07f7bf669abeac259bb2e771c881e8597921.best_config b/progress/SpecForge/cache/compiled_kernels/cj/c43bed5f9e75971f3eee0a5e390e07f7bf669abeac259bb2e771c881e8597921.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c090b439cd98894f50ef817b2f3b529d58dbee1e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cj/c43bed5f9e75971f3eee0a5e390e07f7bf669abeac259bb2e771c881e8597921.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 89, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/cj/ccjnuksl3elmgei4is5wcep6ywju2v3wqhpxexxbjwglhwyy7fbx.py b/progress/SpecForge/cache/compiled_kernels/cj/ccjnuksl3elmgei4is5wcep6ywju2v3wqhpxexxbjwglhwyy7fbx.py new file mode 100644 index 0000000000000000000000000000000000000000..3b89dcb99cd677861b326b81d80affe8cdb6087a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cj/ccjnuksl3elmgei4is5wcep6ywju2v3wqhpxexxbjwglhwyy7fbx.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/cl/cclnm6jgey7t7upufkwjnmo56ymha23ivgv2xu5majsar2fkn6pq.py b/progress/SpecForge/cache/compiled_kernels/cl/cclnm6jgey7t7upufkwjnmo56ymha23ivgv2xu5majsar2fkn6pq.py new file mode 100644 index 0000000000000000000000000000000000000000..fef9cee6ed8f564d1f258fe56fab576f2423caed --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cl/cclnm6jgey7t7upufkwjnmo56ymha23ivgv2xu5majsar2fkn6pq.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/cm/ccmcwt34ssxxiolp4xhmnmei4mkzflk4kn4hur3uo2pp2fq56dxm.py b/progress/SpecForge/cache/compiled_kernels/cm/ccmcwt34ssxxiolp4xhmnmei4mkzflk4kn4hur3uo2pp2fq56dxm.py new file mode 100644 index 0000000000000000000000000000000000000000..df7d8adb59b506428344ca56b4720b2bc177837f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cm/ccmcwt34ssxxiolp4xhmnmei4mkzflk4kn4hur3uo2pp2fq56dxm.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7j/c7jy7orghqlzyryav6xs6tfj7jilalobanvnpyxyky7yueuhrwlx.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream0) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 528 + primals_2 = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = 528 + primals_4 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 528 + primals_6 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = 5 + primals_8 = 5 + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_10 = 528 + primals_11 = 528 + primals_12 = 5 + primals_13 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_14 = 5 + primals_15 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_16 = 5 + primals_17 = 5 + primals_18 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_19 = 5 + primals_20 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_21 = 5 + primals_22 = 5 + primals_23 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_24 = 5 + primals_25 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_26 = 5 + primals_27 = 5 + primals_28 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/cz/c075ff7db83661cc0fa231438103224a12a7ae19591b80376ef90a86391676d6.best_config b/progress/SpecForge/cache/compiled_kernels/cz/c075ff7db83661cc0fa231438103224a12a7ae19591b80376ef90a86391676d6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf23bb6b2b7bb1100f67493396f7100cd3420a6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cz/c075ff7db83661cc0fa231438103224a12a7ae19591b80376ef90a86391676d6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "LGPZNA72RPSJYHINN2K5UEVKEID3BGMZXX6OKY62QTFBTMK4ZS5Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/cz/ccznsfhqqvjcsiqwieot6ti7mbssrwnaisfi52ehkwx66zkk2vuk.py b/progress/SpecForge/cache/compiled_kernels/cz/ccznsfhqqvjcsiqwieot6ti7mbssrwnaisfi52ehkwx66zkk2vuk.py new file mode 100644 index 0000000000000000000000000000000000000000..152b09063f0be82f7a1712b9422ba4fafbf9e4e6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/cz/ccznsfhqqvjcsiqwieot6ti7mbssrwnaisfi52ehkwx66zkk2vuk.py @@ -0,0 +1,58 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/d3/22a077a3e608253d0da79ba42e7a58c70a6ee9a8441fa5d71af70c383191998e.best_config b/progress/SpecForge/cache/compiled_kernels/d3/22a077a3e608253d0da79ba42e7a58c70a6ee9a8441fa5d71af70c383191998e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..42b3014d947d178f88186eb769b74318572b26a5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/d3/22a077a3e608253d0da79ba42e7a58c70a6ee9a8441fa5d71af70c383191998e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "77M7WJG2OTWWBKLIVTXYS5WS72TBO4MMOVS4LCYXAUYIERVNOMWA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/d3/cd3fyk2an5e67cbkt4zfg2vrfg3ra6vdbivectsfiwzgfqjpyudy.py b/progress/SpecForge/cache/compiled_kernels/d3/cd3fyk2an5e67cbkt4zfg2vrfg3ra6vdbivectsfiwzgfqjpyudy.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9dce92f5571cafba80dca220a14beb5ec16f31 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/d3/cd3fyk2an5e67cbkt4zfg2vrfg3ra6vdbivectsfiwzgfqjpyudy.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/progress/SpecForge/cache/compiled_kernels/d4/cd4cume5h5jawshqqjbr2jblovq432toqs3lgb5cb2kez263sgbp.py b/progress/SpecForge/cache/compiled_kernels/d4/cd4cume5h5jawshqqjbr2jblovq432toqs3lgb5cb2kez263sgbp.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb4171fdb54c607216201c1cae029ae14c2601a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/d4/cd4cume5h5jawshqqjbr2jblovq432toqs3lgb5cb2kez263sgbp.py @@ -0,0 +1,715 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/n5/cn577wod4hnt5ilfoi2wqa2ani6ai3hpj6oyen4zalxgng7t6uko.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ym/cymmhz7ggoire32qvayd4g525ooohofhikb6v7n5vukewzni3xbe.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream7) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream7) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/da/cdasofnd3pbbt2ztadrt7liw6r5jwtyuf5j5jlt43d4d6wrqqs5w.py b/progress/SpecForge/cache/compiled_kernels/da/cdasofnd3pbbt2ztadrt7liw6r5jwtyuf5j5jlt43d4d6wrqqs5w.py new file mode 100644 index 0000000000000000000000000000000000000000..6b556b232b6d9d5828a6b023349c36d00b1318d8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/da/cdasofnd3pbbt2ztadrt7liw6r5jwtyuf5j5jlt43d4d6wrqqs5w.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dc/cdckrc2kp2ofwceybjd3ojozj6ljcw3larn77vgi5rvdtdirp7tj.py b/progress/SpecForge/cache/compiled_kernels/dc/cdckrc2kp2ofwceybjd3ojozj6ljcw3larn77vgi5rvdtdirp7tj.py new file mode 100644 index 0000000000000000000000000000000000000000..b8781d92b39a4ccf658ce751e3f92c89b4d1f132 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dc/cdckrc2kp2ofwceybjd3ojozj6ljcw3larn77vgi5rvdtdirp7tj.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dk/cdkowmfsej5zb5z5dk4mveubks5yfg6z6z7qgigynylbsktyxd53.py b/progress/SpecForge/cache/compiled_kernels/dk/cdkowmfsej5zb5z5dk4mveubks5yfg6z6z7qgigynylbsktyxd53.py new file mode 100644 index 0000000000000000000000000000000000000000..764fd07dfab73c73a6a1e77321c0ab2f169f09b8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dk/cdkowmfsej5zb5z5dk4mveubks5yfg6z6z7qgigynylbsktyxd53.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dl/cdlvhrfrlsdl44ce23aa25hvc36zpvaupgofi6ofja6f4mrv4qwi.py b/progress/SpecForge/cache/compiled_kernels/dl/cdlvhrfrlsdl44ce23aa25hvc36zpvaupgofi6ofja6f4mrv4qwi.py new file mode 100644 index 0000000000000000000000000000000000000000..a71861266f2691b15c2844765ac39f9c213bf577 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dl/cdlvhrfrlsdl44ce23aa25hvc36zpvaupgofi6ofja6f4mrv4qwi.py @@ -0,0 +1,1018 @@ +# AOT ID: ['0_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ck/cckqdjssfjga6vnnhmdsek4txynwhq3yriokhqxjxjhfyc7jmwm3.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 190464, 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 1488][49152, 1536, 1]cuda:0" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1488, 1488, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 571392, 'r0_': 24379392}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 47616 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 1488) + x1 = xindex // 1488 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/mm/cmmzaorjj4g3ymvdtwcwfzc7dewncghsib5pmziiutfy3atqqak6.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 190464, 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_4 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_5] +# %primals_8 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1488, 1488, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1523712, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 6094848, 190464, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 6094848, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 12 + stride_q_idx_h = 144 + stride_q_idx_n = 12 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 190464*off_hkv + 1523712*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1488 + KV_LEN = 1488 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1488 + KV_LEN = 1488 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1488, 128), (6094848, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_5, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_6, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_7, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_8, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_9, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_10, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_11, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(getitem, (1, 32, 1488, 128), (6094848, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 1488), (47616, 1488, 1)) + assert_size_stride(tangents_1, (1, 32, 1488, 128), (6094848, 190464, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 1488), (47616, 1488, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf1 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 47616, 128, stream=stream0) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 1488, 128), (6094848, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 1488, 128), (1523712, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 1488, 128), (1523712, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_4, primals_5, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 60, 1, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1488, 128), (6094848, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((1, 32, 1488, 128), (6094848, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1488), (47616, 1488, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1488, 128), (6094848, 190464, 128, 1), device='cuda:0', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1488), (47616, 1488, 1), device='cuda:0', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py b/progress/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py new file mode 100644 index 0000000000000000000000000000000000000000..782b91adce0a242370c05dfa376b50cc00bd2685 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ds/cdswhsa4g3evztckrmjsna5vwwa237s7turomspqnonilwlogci2.py b/progress/SpecForge/cache/compiled_kernels/ds/cdswhsa4g3evztckrmjsna5vwwa237s7turomspqnonilwlogci2.py new file mode 100644 index 0000000000000000000000000000000000000000..d33722a6f507d835a29a59552f66196977885195 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ds/cdswhsa4g3evztckrmjsna5vwwa237s7turomspqnonilwlogci2.py @@ -0,0 +1,879 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ls/cls5ocxpu5yuevimgi3zbr2b6p42lr4ardundcj2ouobmd7zu7s7.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l5/cl55txm3v3zkxpa3pil7bnjoyoktfcna32knl5c62go2j7vfovty.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/pw/cpwrsw3u5cqvqbtxwdjs2iuqfcanaiyhva2khkrgjw645n62wq35.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 262144, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream6) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream6) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream6 = get_raw_stream(6) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream6) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 64 + arg1_1 = rand_strided((1, 32, 64, 128), (262144, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 64 + arg3_1 = rand_strided((1, 8, 64, 128), (65536, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg4_1 = 64 + arg5_1 = rand_strided((1, 8, 64, 128), (65536, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg7_1 = 64 + arg8_1 = 64 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/dt/cdt7p2o3ok2wfcldv4qsdhjija3xekjiu6ve4x6lyf3bqcrkvbz5.py b/progress/SpecForge/cache/compiled_kernels/dt/cdt7p2o3ok2wfcldv4qsdhjija3xekjiu6ve4x6lyf3bqcrkvbz5.py new file mode 100644 index 0000000000000000000000000000000000000000..ab73a55d0d1dc11c4342ec08e10ac9b4afad7701 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dt/cdt7p2o3ok2wfcldv4qsdhjija3xekjiu6ve4x6lyf3bqcrkvbz5.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dt/cdttccr5ayicj2e2fsdmzhr5vyuw257qijyxlahm353lqqotgupo.py b/progress/SpecForge/cache/compiled_kernels/dt/cdttccr5ayicj2e2fsdmzhr5vyuw257qijyxlahm353lqqotgupo.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff2f098718ead6c16da70bcea8d7fb3493585bc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dt/cdttccr5ayicj2e2fsdmzhr5vyuw257qijyxlahm353lqqotgupo.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dw/cdw5wjkx54bjh766souwq7xccuojrwfvuwen2nweaqgcm7sy5blc.py b/progress/SpecForge/cache/compiled_kernels/dw/cdw5wjkx54bjh766souwq7xccuojrwfvuwen2nweaqgcm7sy5blc.py new file mode 100644 index 0000000000000000000000000000000000000000..a673e4ecf78f5a34224568ce779de055e2146676 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dw/cdw5wjkx54bjh766souwq7xccuojrwfvuwen2nweaqgcm7sy5blc.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/dz/cdzd3gnh5jxzakzskyqhxwjzbb7q6mgo5im4n62dzrsttfdu52w3.py b/progress/SpecForge/cache/compiled_kernels/dz/cdzd3gnh5jxzakzskyqhxwjzbb7q6mgo5im4n62dzrsttfdu52w3.py new file mode 100644 index 0000000000000000000000000000000000000000..b559343b64c43e33fa019bf9baef86dc54e28649 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dz/cdzd3gnh5jxzakzskyqhxwjzbb7q6mgo5im4n62dzrsttfdu52w3.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/dz/cdzyh4cq33vsuc7cr5l5nu7g4cpcwibmuz4nyrgo64ilhvfzonno.py b/progress/SpecForge/cache/compiled_kernels/dz/cdzyh4cq33vsuc7cr5l5nu7g4cpcwibmuz4nyrgo64ilhvfzonno.py new file mode 100644 index 0000000000000000000000000000000000000000..1d36b3e9ab6bf79727452155b6926ac5f29b21b1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/dz/cdzyh4cq33vsuc7cr5l5nu7g4cpcwibmuz4nyrgo64ilhvfzonno.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/or/cor47z6eozmuk4wdukvlmvzymu3noysuwi3iwfctlrm6sdlvjt4r.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ad/cadsewajq2wflihzw6tr3xgecj2ixcsvqwffap7xvueq5e3obe4p.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream4) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream4) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 960 + primals_2 = rand_strided((1, 32, 960, 128), (3932160, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = 960 + primals_4 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_5 = 960 + primals_6 = rand_strided((1, 8, 960, 128), (983040, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = 8 + primals_8 = 8 + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_10 = 960 + primals_11 = 960 + primals_12 = 8 + primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_14 = 8 + primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_16 = 8 + primals_17 = 8 + primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_19 = 8 + primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_21 = 8 + primals_22 = 8 + primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + primals_24 = 8 + primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:4', dtype=torch.int32) + primals_26 = 8 + primals_27 = 8 + primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ea/ceajfbjjrurt6t4led3jw4lyfkf23ku2btxjrnxzwjgjo26pomeq.py b/progress/SpecForge/cache/compiled_kernels/ea/ceajfbjjrurt6t4led3jw4lyfkf23ku2btxjrnxzwjgjo26pomeq.py new file mode 100644 index 0000000000000000000000000000000000000000..a0552c9310439a54f2c5e3f1fbef9d2308494a40 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ea/ceajfbjjrurt6t4led3jw4lyfkf23ku2btxjrnxzwjgjo26pomeq.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2686976, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 671744, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 671744, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 656 + ZKV = 1 + KV_LEN = 656 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 83968*idx_hq + 2686976*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/ec/ceccbij5lfuzxm2pbersrqg5srylk62lnljlwzwpupzqc62kqaqw.py b/progress/SpecForge/cache/compiled_kernels/ec/ceccbij5lfuzxm2pbersrqg5srylk62lnljlwzwpupzqc62kqaqw.py new file mode 100644 index 0000000000000000000000000000000000000000..60d752f3aaad604a8d91411b1828332cfe20c683 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ec/ceccbij5lfuzxm2pbersrqg5srylk62lnljlwzwpupzqc62kqaqw.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 10 + stride_q_idx_h = 100 + stride_q_idx_n = 10 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ef/cef7repagltuignxgm2fqfqsainl45fihynh7brhhbxywg32rsw7.py b/progress/SpecForge/cache/compiled_kernels/ef/cef7repagltuignxgm2fqfqsainl45fihynh7brhhbxywg32rsw7.py new file mode 100644 index 0000000000000000000000000000000000000000..3fd619468d416595acdfd1c1a393e90c5dc84bb4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ef/cef7repagltuignxgm2fqfqsainl45fihynh7brhhbxywg32rsw7.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hk/chkgaqygcy26eau7yvwmovujnogazcg54xtra6atvqxpqdityswt.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gf/cgfi2fecop7lamd73ab667kqzjq4tdvg2or5tdxq5nxbdnilt56y.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:0" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:0" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:0" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:0" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:0" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:0" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_15, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_16, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_17, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_18, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_19, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream0) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 720 + primals_11 = 720 + primals_7 = 6 + primals_8 = 6 + primals_12 = 6 + primals_2 = rand_strided((1, 32, 720, 128), (2949120, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 720, 128), (737280, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 720, 128), (737280, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((1, 32, 720, 128), (2949120, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 720), (23040, 720, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 720, 128), (2949120, 92160, 128, 1), device='cuda:0', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 720), (23040, 720, 1), device='cuda:0', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/eg/cegljumpi4qxy2tkfjauhtir6aympsfoey2neok2c3kyxems73am.py b/progress/SpecForge/cache/compiled_kernels/eg/cegljumpi4qxy2tkfjauhtir6aympsfoey2neok2c3kyxems73am.py new file mode 100644 index 0000000000000000000000000000000000000000..773de155f7631236e0219e13a8d7e0fad35dae54 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/eg/cegljumpi4qxy2tkfjauhtir6aympsfoey2neok2c3kyxems73am.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ei/cei2qreihzzd3pvvhvbqlkichj2eao3xpazmjjdvr7cqq2i2roxx.py b/progress/SpecForge/cache/compiled_kernels/ei/cei2qreihzzd3pvvhvbqlkichj2eao3xpazmjjdvr7cqq2i2roxx.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9345e2f76d1cd5c8de16aed5bd89dfb40c9567 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ei/cei2qreihzzd3pvvhvbqlkichj2eao3xpazmjjdvr7cqq2i2roxx.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2228224, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 557056, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 557056, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 544 + ZKV = 1 + KV_LEN = 544 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 69632*idx_hq + 2228224*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py b/progress/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py new file mode 100644 index 0000000000000000000000000000000000000000..5eafaf759ae9f3477f9b26d04825c137df2f4b30 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ek/cekdygnnt4twwaq4fpapciid2veg5uc5gzwte4mymxe7ertv26cs.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/eo/ceo5bo777l254wspghczpzhp27xk6bfxmnen3ewj2bsfu5nhxaw4.py b/progress/SpecForge/cache/compiled_kernels/eo/ceo5bo777l254wspghczpzhp27xk6bfxmnen3ewj2bsfu5nhxaw4.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b7b734fe9308256ae2dc2cffb3eac4a4c1bace --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/eo/ceo5bo777l254wspghczpzhp27xk6bfxmnen3ewj2bsfu5nhxaw4.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/et/cethqmglrjg222xiohhw5z53iu343havwpgw7crg5kstda2pk5rl.py b/progress/SpecForge/cache/compiled_kernels/et/cethqmglrjg222xiohhw5z53iu343havwpgw7crg5kstda2pk5rl.py new file mode 100644 index 0000000000000000000000000000000000000000..9c32d0e98717860afa19d9e15021664f7d6675fc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/et/cethqmglrjg222xiohhw5z53iu343havwpgw7crg5kstda2pk5rl.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/et/cetrbakv4rvfoerl3oz3yadt5lvdgav53vbkddxj7uzevebcv5w2.py b/progress/SpecForge/cache/compiled_kernels/et/cetrbakv4rvfoerl3oz3yadt5lvdgav53vbkddxj7uzevebcv5w2.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1736ae276a0048c520f9bc6013fb541f385529 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/et/cetrbakv4rvfoerl3oz3yadt5lvdgav53vbkddxj7uzevebcv5w2.py @@ -0,0 +1,721 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/vj/cvj4vgdwmfi5ohh7mn6zi4ttb7xw7ozyhs3to7ugrnlgkitqtkp2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:6" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:6" = PlaceHolder[target=arg8_1] +# %arg13_1 : Tensor "i32[1, 1, 7][7, 7, 1]cuda:6" = PlaceHolder[target=arg13_1] +# %arg14_1 : Tensor "i32[1, 1, 7, 7][49, 49, 7, 1]cuda:6" = PlaceHolder[target=arg14_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg13_1, %arg14_1, %arg15_1, %arg16_1, %arg17_1, %arg18_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xk/cxktzujto4ftosl5gghkudqc4rbz3ix4zcxmbpo5ijvvteuoihei.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg13_1, (1, 1, 7), (7, 7, 1)) + assert_size_stride(arg14_1, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(arg15_1, (1, 1, 7), (7, 7, 1)) + assert_size_stride(arg16_1, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(arg17_1, (1, 1, 7), (7, 7, 1)) + assert_size_stride(arg18_1, (1, 1, 7, 7), (49, 49, 7, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg13_1, arg14_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream6) + del arg12_1 + del arg13_1 + del arg14_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream6) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 848 + arg1_1 = rand_strided((1, 32, 848, 128), (3473408, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 848 + arg3_1 = rand_strided((1, 8, 848, 128), (868352, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg4_1 = 848 + arg5_1 = rand_strided((1, 8, 848, 128), (868352, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg6_1 = 7 + arg7_1 = 7 + arg8_1 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:6', dtype=torch.int32) + arg9_1 = 848 + arg10_1 = 848 + arg11_1 = 7 + arg12_1 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:6', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:6', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:6', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:6', dtype=torch.int32) + arg16_1 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:6', dtype=torch.int32) + arg17_1 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:6', dtype=torch.int32) + arg18_1 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ev/cevamvsxwskhfc5s2mwddfuvfgvcag6ifdqry63nn7reibnpybuz.py b/progress/SpecForge/cache/compiled_kernels/ev/cevamvsxwskhfc5s2mwddfuvfgvcag6ifdqry63nn7reibnpybuz.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0d62a7ae5468003f1d6a89620897a7eba384a9 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ev/cevamvsxwskhfc5s2mwddfuvfgvcag6ifdqry63nn7reibnpybuz.py @@ -0,0 +1,1019 @@ +# AOT ID: ['5_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream6) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 128 + primals_9 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:6', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ew/cewn2ayqifa2nhwgosykqxi6glpfgcflbcihzqk4harxwe5hv5a3.py b/progress/SpecForge/cache/compiled_kernels/ew/cewn2ayqifa2nhwgosykqxi6glpfgcflbcihzqk4harxwe5hv5a3.py new file mode 100644 index 0000000000000000000000000000000000000000..19160efbd33986eb7cd6eed7b4d5efb158a92d0c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ew/cewn2ayqifa2nhwgosykqxi6glpfgcflbcihzqk4harxwe5hv5a3.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/f5/cf5mtbciohvjb5qlwohplikqved5fhg4poa6qi6pjpwttwa6ct3f.py b/progress/SpecForge/cache/compiled_kernels/f5/cf5mtbciohvjb5qlwohplikqved5fhg4poa6qi6pjpwttwa6ct3f.py new file mode 100644 index 0000000000000000000000000000000000000000..41792ed4ee9e3c8b706d3f025919089eccb93cea --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/f5/cf5mtbciohvjb5qlwohplikqved5fhg4poa6qi6pjpwttwa6ct3f.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/li/clinlrn6ucguu7skp4i3gj5ozytngsgi3lwbmsbk3e5v6mhi56ld.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl7z52obzeb6b6wt3rif2ui2hg4y7au6i6bu65dywj7kjitp2m2h.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream3) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream3) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 928 + arg1_1 = rand_strided((1, 32, 928, 128), (3801088, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + arg2_1 = 928 + arg3_1 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg4_1 = 928 + arg5_1 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg6_1 = 8 + arg7_1 = 8 + arg8_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + arg9_1 = 928 + arg10_1 = 928 + arg11_1 = 8 + arg12_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + arg13_1 = 8 + arg14_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + arg15_1 = 8 + arg16_1 = 8 + arg17_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + arg18_1 = 8 + arg19_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + arg20_1 = 8 + arg21_1 = 8 + arg22_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + arg23_1 = 8 + arg24_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + arg25_1 = 8 + arg26_1 = 8 + arg27_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/fd/cfdl4pibln4tirqnyubvpb533hijj7xhvahp4kbfhq3ieoh2isxu.py b/progress/SpecForge/cache/compiled_kernels/fd/cfdl4pibln4tirqnyubvpb533hijj7xhvahp4kbfhq3ieoh2isxu.py new file mode 100644 index 0000000000000000000000000000000000000000..099825563e79898bed9855a8ef8fa63f1ada5e83 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fd/cfdl4pibln4tirqnyubvpb533hijj7xhvahp4kbfhq3ieoh2isxu.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 638976, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 2555904, 79872, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 2555904, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 5 + stride_q_idx_h = 25 + stride_q_idx_n = 5 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 79872*off_hkv + 638976*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/ff/cfffrqcj3ddwf5otvon76vcaycaxtnpswtkcsbyrgevbbr4ecqpp.py b/progress/SpecForge/cache/compiled_kernels/ff/cfffrqcj3ddwf5otvon76vcaycaxtnpswtkcsbyrgevbbr4ecqpp.py new file mode 100644 index 0000000000000000000000000000000000000000..1c3553fac4785919ccd7b9b8ffa640df8a5aaa4b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ff/cfffrqcj3ddwf5otvon76vcaycaxtnpswtkcsbyrgevbbr4ecqpp.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wi/cwiof4iou6xjj5lsqeya5ndeadiieeztoux6qjj6lgcomom2mhjv.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4x/c4xjhgyzut6anhrjeinspoinohfxvyl6skr4gd3vfrscrvsevmya.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream4) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream4) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 736 + arg1_1 = rand_strided((1, 32, 736, 128), (3014656, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = 736 + arg3_1 = rand_strided((1, 8, 736, 128), (753664, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg4_1 = 736 + arg5_1 = rand_strided((1, 8, 736, 128), (753664, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg6_1 = 6 + arg7_1 = 6 + arg8_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:4', dtype=torch.int32) + arg9_1 = 736 + arg10_1 = 736 + arg11_1 = 6 + arg12_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:4', dtype=torch.int32) + arg13_1 = 6 + arg14_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:4', dtype=torch.int32) + arg15_1 = 6 + arg16_1 = 6 + arg17_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:4', dtype=torch.int32) + arg18_1 = 6 + arg19_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:4', dtype=torch.int32) + arg20_1 = 6 + arg21_1 = 6 + arg22_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:4', dtype=torch.int32) + arg23_1 = 6 + arg24_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:4', dtype=torch.int32) + arg25_1 = 6 + arg26_1 = 6 + arg27_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/fl/cflefu7axt6a7sscgk4b3py7yi7rzlb3n3dm24uifa24gvzbuv4e.py b/progress/SpecForge/cache/compiled_kernels/fl/cflefu7axt6a7sscgk4b3py7yi7rzlb3n3dm24uifa24gvzbuv4e.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9bd5c58c60affa4ed11779fc36842d40de00e2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fl/cflefu7axt6a7sscgk4b3py7yi7rzlb3n3dm24uifa24gvzbuv4e.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4980736, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1245184, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1245184, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1216 + ZKV = 1 + KV_LEN = 1216 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 10 + stride_kv_idx_h = 100 + stride_kv_idx_m = 10 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 155648*idx_hq + 4980736*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/fl/cflsk7m2lkh24yw73thyxfxsyjuoyurybfk5lpqvriwatv5g6eyn.py b/progress/SpecForge/cache/compiled_kernels/fl/cflsk7m2lkh24yw73thyxfxsyjuoyurybfk5lpqvriwatv5g6eyn.py new file mode 100644 index 0000000000000000000000000000000000000000..b21143c10ef7e87619e2398350ded8728ad192dc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fl/cflsk7m2lkh24yw73thyxfxsyjuoyurybfk5lpqvriwatv5g6eyn.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/vi/cviuncbgf6gw2ilej3jwbvnvnhxltsoxsxysnqbw2r6nrbtpxkk5.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:1" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:1" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/kc/ckcrxe3eiivcle4pzzfnoxflbgl5wp2ssdfqd2wwvgeyvlpbtiwx.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream1) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 704 + primals_2 = rand_strided((1, 32, 704, 128), (2883584, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = 704 + primals_4 = rand_strided((1, 8, 704, 128), (720896, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = 704 + primals_6 = rand_strided((1, 8, 704, 128), (720896, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = 6 + primals_8 = 6 + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_10 = 704 + primals_11 = 704 + primals_12 = 6 + primals_13 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_14 = 6 + primals_15 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_16 = 6 + primals_17 = 6 + primals_18 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_19 = 6 + primals_20 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_21 = 6 + primals_22 = 6 + primals_23 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_24 = 6 + primals_25 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_26 = 6 + primals_27 = 6 + primals_28 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/fq/cfq7pxuvbsyxbbbfloayuburrfgsb2amghjb4rksnm57d57oz7wn.py b/progress/SpecForge/cache/compiled_kernels/fq/cfq7pxuvbsyxbbbfloayuburrfgsb2amghjb4rksnm57d57oz7wn.py new file mode 100644 index 0000000000000000000000000000000000000000..feaacda01fdc103c66c4fac765a2421b140593a2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fq/cfq7pxuvbsyxbbbfloayuburrfgsb2amghjb4rksnm57d57oz7wn.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/fq/cfq7rf6sdd6jc4wqc2w2e2xhz5or446yhnc5junu7g4u53sc7vfo.py b/progress/SpecForge/cache/compiled_kernels/fq/cfq7rf6sdd6jc4wqc2w2e2xhz5or446yhnc5junu7g4u53sc7vfo.py new file mode 100644 index 0000000000000000000000000000000000000000..adcefeb85d8634dea223112de1352af061c268a6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fq/cfq7rf6sdd6jc4wqc2w2e2xhz5or446yhnc5junu7g4u53sc7vfo.py @@ -0,0 +1,721 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/77/c77answdbsc3cyb3vrut4dwz2fsgr7k7rimvme63nesuaeqlvtxn.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=arg8_1] +# %arg13_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:2" = PlaceHolder[target=arg13_1] +# %arg14_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:2" = PlaceHolder[target=arg14_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg13_1, %arg14_1, %arg15_1, %arg16_1, %arg17_1, %arg18_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/rd/crdz6vb3qazqcygmnqh575uk75hmy45wshrc6bjp3qeelvcm7u3f.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg13_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg14_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg15_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg16_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg17_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg18_1, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg13_1, arg14_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream2) + del arg12_1 + del arg13_1 + del arg14_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 688 + arg1_1 = rand_strided((1, 32, 688, 128), (2818048, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = 688 + arg3_1 = rand_strided((1, 8, 688, 128), (704512, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg4_1 = 688 + arg5_1 = rand_strided((1, 8, 688, 128), (704512, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg6_1 = 6 + arg7_1 = 6 + arg8_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:2', dtype=torch.int32) + arg9_1 = 688 + arg10_1 = 688 + arg11_1 = 6 + arg12_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:2', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:2', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:2', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:2', dtype=torch.int32) + arg16_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:2', dtype=torch.int32) + arg17_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:2', dtype=torch.int32) + arg18_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/fs/cfs2of6ygpacl7vcs74vy37xwnw4avrqcbqon6ht56dira6khkkl.py b/progress/SpecForge/cache/compiled_kernels/fs/cfs2of6ygpacl7vcs74vy37xwnw4avrqcbqon6ht56dira6khkkl.py new file mode 100644 index 0000000000000000000000000000000000000000..39f1708a073416574066bb7ed2d872055489230f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fs/cfs2of6ygpacl7vcs74vy37xwnw4avrqcbqon6ht56dira6khkkl.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/ft/5a950ab694097b1286c7b2b8a8a4c7d39a2fd8e2e9c3aacc079bbbdfb787c47c.best_config b/progress/SpecForge/cache/compiled_kernels/ft/5a950ab694097b1286c7b2b8a8a4c7d39a2fd8e2e9c3aacc079bbbdfb787c47c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ft/5a950ab694097b1286c7b2b8a8a4c7d39a2fd8e2e9c3aacc079bbbdfb787c47c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py b/progress/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py new file mode 100644 index 0000000000000000000000000000000000000000..51b1cd97bbc5bdc8b5fcfab2cc8c38021d2910a9 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ft/cftmlennkcgyn4ynz7zxqohr2jlirziu3mfte3b4eg5y2466jcwm.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/fy/cfyocykti4yhb5mgvyvyw6bnmdpc66bcxv6j5bk5ncbu63xljovz.py b/progress/SpecForge/cache/compiled_kernels/fy/cfyocykti4yhb5mgvyvyw6bnmdpc66bcxv6j5bk5ncbu63xljovz.py new file mode 100644 index 0000000000000000000000000000000000000000..354c85b834201be4a9c9de48153f6c308ec5d107 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/fy/cfyocykti4yhb5mgvyvyw6bnmdpc66bcxv6j5bk5ncbu63xljovz.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gb/cgb3qo2uhbvd22jllo3pcd4qrm2qklyyeuprletllykljldff6hp.py b/progress/SpecForge/cache/compiled_kernels/gb/cgb3qo2uhbvd22jllo3pcd4qrm2qklyyeuprletllykljldff6hp.py new file mode 100644 index 0000000000000000000000000000000000000000..70904a7a9ec2c439207386798c44455e2d064248 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gb/cgb3qo2uhbvd22jllo3pcd4qrm2qklyyeuprletllykljldff6hp.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 219136*idx_hq + 7012352*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/gc/9c80fd6bab99677f7c2ea7fc848c111f28fc7a247554c284ca2bf4cf020e7f8f.best_config b/progress/SpecForge/cache/compiled_kernels/gc/9c80fd6bab99677f7c2ea7fc848c111f28fc7a247554c284ca2bf4cf020e7f8f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c65e5f6854e0be08b79c2f27f87410d51652dc21 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gc/9c80fd6bab99677f7c2ea7fc848c111f28fc7a247554c284ca2bf4cf020e7f8f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 20, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gc/cgcc6zohapv7esv6a7jogw5nsjkk22w2xvind6oatqo3pvk55mee.py b/progress/SpecForge/cache/compiled_kernels/gc/cgcc6zohapv7esv6a7jogw5nsjkk22w2xvind6oatqo3pvk55mee.py new file mode 100644 index 0000000000000000000000000000000000000000..38876111acd04e9904991163aaa47b2e2a10806b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gc/cgcc6zohapv7esv6a7jogw5nsjkk22w2xvind6oatqo3pvk55mee.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16384}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/gc/cgceywn6audeh4gx6twoifhmr6vj6e3pjftmmpzgsfjri6cwdxdv.py b/progress/SpecForge/cache/compiled_kernels/gc/cgceywn6audeh4gx6twoifhmr6vj6e3pjftmmpzgsfjri6cwdxdv.py new file mode 100644 index 0000000000000000000000000000000000000000..578d17fc3735a75e30f9a6705c2de6359fa360d1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gc/cgceywn6audeh4gx6twoifhmr6vj6e3pjftmmpzgsfjri6cwdxdv.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/ge/cgeg6fb7c2vntq3ehpn65bqpr7pq5k4rqrbgdrcvd7ylvpnezy7o.py b/progress/SpecForge/cache/compiled_kernels/ge/cgeg6fb7c2vntq3ehpn65bqpr7pq5k4rqrbgdrcvd7ylvpnezy7o.py new file mode 100644 index 0000000000000000000000000000000000000000..f54600cbee5fa0adc4a6e0612f65203e52d8d9a5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ge/cgeg6fb7c2vntq3ehpn65bqpr7pq5k4rqrbgdrcvd7ylvpnezy7o.py @@ -0,0 +1,876 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/vi/cvizg7evyreummzwbntncwz4vwgv7c4wrvd4ifzggdjmccpbe2oh.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ze/czem3fqfwlqohxvpjx3ojg5cfagoqnqmqo7bqx6oqze4sdb2gel5.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem_1 +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=2] = call_function[target=operator.getitem](args = (%flex_attention, 1), kwargs = {}) +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%getitem_1,%mul_15 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6d/c6dg7vz4gdbqxi3w2qjuidwukwzat7wmle6l4qqt6azqfox3i34w.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, 8, 32, 1, stream=stream6) + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + buf11 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, buf11, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream6) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream6 = get_raw_stream(6) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream6) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf11, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf9, buf10, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 112 + primals_2 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = 112 + primals_4 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = 112 + primals_6 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_8 = 112 + primals_9 = 112 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/gf/cgfi2fecop7lamd73ab667kqzjq4tdvg2or5tdxq5nxbdnilt56y.py b/progress/SpecForge/cache/compiled_kernels/gf/cgfi2fecop7lamd73ab667kqzjq4tdvg2or5tdxq5nxbdnilt56y.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7411d3cbb5b41f8d5dc2dc8995036bf5485e11 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gf/cgfi2fecop7lamd73ab667kqzjq4tdvg2or5tdxq5nxbdnilt56y.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gf/cgfjqp3yzvxlp7ams5la3kknspmfxvblk34bc2gcskwyjy6cxwvr.py b/progress/SpecForge/cache/compiled_kernels/gf/cgfjqp3yzvxlp7ams5la3kknspmfxvblk34bc2gcskwyjy6cxwvr.py new file mode 100644 index 0000000000000000000000000000000000000000..e507ed8cc3a9b7d0a68f8b7e8044f8c0625a9dbe --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gf/cgfjqp3yzvxlp7ams5la3kknspmfxvblk34bc2gcskwyjy6cxwvr.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gm/cgmlz3nxyooiuhc23hyijhc4q5woe52focxxzppau7fyaxzub6pg.py b/progress/SpecForge/cache/compiled_kernels/gm/cgmlz3nxyooiuhc23hyijhc4q5woe52focxxzppau7fyaxzub6pg.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4f7825b85f5ae9a8760d1d4947036653319571 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gm/cgmlz3nxyooiuhc23hyijhc4q5woe52focxxzppau7fyaxzub6pg.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gp/cgp2dfkhh3dgo4jrh7maowcgfq7k26urclql3og7kxyovcn7wvre.py b/progress/SpecForge/cache/compiled_kernels/gp/cgp2dfkhh3dgo4jrh7maowcgfq7k26urclql3og7kxyovcn7wvre.py new file mode 100644 index 0000000000000000000000000000000000000000..1db4f1e31ab31d3d8c3fd9d64e89f236819d509f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gp/cgp2dfkhh3dgo4jrh7maowcgfq7k26urclql3og7kxyovcn7wvre.py @@ -0,0 +1,715 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ba/cbaninj3sj2mvebp5i3voyltfrg6ax5a2t3s36hrhl6klrlwbhlr.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:6" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dj/cdjp4iuumjhdnimksugfso226k4mkq2xdsujxzfvy7rwva7cr2yp.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream6) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream6) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/gr/2c0777f5c46c856fbc5a0e485294f2a700db7463a9606eb9ad14b2249eed01b5.best_config b/progress/SpecForge/cache/compiled_kernels/gr/2c0777f5c46c856fbc5a0e485294f2a700db7463a9606eb9ad14b2249eed01b5.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c56b3ff6df726fa6b67725165b8989b16d820629 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gr/2c0777f5c46c856fbc5a0e485294f2a700db7463a9606eb9ad14b2249eed01b5.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gr/cgrejeacofqq63la2iuk35gh33pidqsdskeydgbrljd5md7pghkc.py b/progress/SpecForge/cache/compiled_kernels/gr/cgrejeacofqq63la2iuk35gh33pidqsdskeydgbrljd5md7pghkc.py new file mode 100644 index 0000000000000000000000000000000000000000..660023a00631e26378c777ea567ea31e0b773d52 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gr/cgrejeacofqq63la2iuk35gh33pidqsdskeydgbrljd5md7pghkc.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/gs/268c7071350d7ae4489d25fbfa593a036d5f97ed76168b6c7b322f8f9e3e90fa.best_config b/progress/SpecForge/cache/compiled_kernels/gs/268c7071350d7ae4489d25fbfa593a036d5f97ed76168b6c7b322f8f9e3e90fa.best_config new file mode 100644 index 0000000000000000000000000000000000000000..000f42315812bb13b78edad0c6bff4a49ad29a2c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gs/268c7071350d7ae4489d25fbfa593a036d5f97ed76168b6c7b322f8f9e3e90fa.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 41, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/gs/cgsdbajyloxe47jtfxdbfdi7ye3lgfn4qhfvprge2u6zknepvixf.py b/progress/SpecForge/cache/compiled_kernels/gs/cgsdbajyloxe47jtfxdbfdi7ye3lgfn4qhfvprge2u6zknepvixf.py new file mode 100644 index 0000000000000000000000000000000000000000..7692a011e0612af22f65d31a455f936588391c4f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/gs/cgsdbajyloxe47jtfxdbfdi7ye3lgfn4qhfvprge2u6zknepvixf.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 262144, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/h2/ch2drkhri22lgifhvi4g5deplfvihiopp72kl2rr3jv6g7sm6f3s.py b/progress/SpecForge/cache/compiled_kernels/h2/ch2drkhri22lgifhvi4g5deplfvihiopp72kl2rr3jv6g7sm6f3s.py new file mode 100644 index 0000000000000000000000000000000000000000..be47a23cabd0c8f068b2efbb5540a529c13aba85 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/h2/ch2drkhri22lgifhvi4g5deplfvihiopp72kl2rr3jv6g7sm6f3s.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4653056, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1163264, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1163264, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1136 + ZKV = 1 + KV_LEN = 1136 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 145408*idx_hq + 4653056*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hb/chb6nflmzovqlxwuawlyyopb6ret64l7cf3lgmxwfk2267e4bpj7.py b/progress/SpecForge/cache/compiled_kernels/hb/chb6nflmzovqlxwuawlyyopb6ret64l7cf3lgmxwfk2267e4bpj7.py new file mode 100644 index 0000000000000000000000000000000000000000..24050e44f93dbe06127e6f83846481dc2e0bbff1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hb/chb6nflmzovqlxwuawlyyopb6ret64l7cf3lgmxwfk2267e4bpj7.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sj/csjgoierbyb7y37xze3raek6tsz2nam2dgfupz3aqsd5fa3pyzto.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sz/csz4oyqya4trc5ll55zjge2jpbfnpn34ci462kngl55bi66jokcx.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream3) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream3) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 928 + primals_2 = rand_strided((1, 32, 928, 128), (3801088, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = 928 + primals_4 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 928 + primals_6 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = 8 + primals_8 = 8 + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_10 = 928 + primals_11 = 928 + primals_12 = 8 + primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_14 = 8 + primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_16 = 8 + primals_17 = 8 + primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_19 = 8 + primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_21 = 8 + primals_22 = 8 + primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_24 = 8 + primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_26 = 8 + primals_27 = 8 + primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hc/892273ed47a1452ddda3ea315a288d4fc3e576c814ec586e94172d561cd1a27c.best_config b/progress/SpecForge/cache/compiled_kernels/hc/892273ed47a1452ddda3ea315a288d4fc3e576c814ec586e94172d561cd1a27c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hc/892273ed47a1452ddda3ea315a288d4fc3e576c814ec586e94172d561cd1a27c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hc/chcapoqvma4vr7nyxct77gqu7ncybp4lj37z3ip24ltc4pdvsp6v.py b/progress/SpecForge/cache/compiled_kernels/hc/chcapoqvma4vr7nyxct77gqu7ncybp4lj37z3ip24ltc4pdvsp6v.py new file mode 100644 index 0000000000000000000000000000000000000000..4656c946b527a4b2e44eb020727e70f2cfb47555 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hc/chcapoqvma4vr7nyxct77gqu7ncybp4lj37z3ip24ltc4pdvsp6v.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/hj/chjiuqziosqwb6jsw22rbpokveo5mfuzvlza6l6lwjc2wou2wm5e.py b/progress/SpecForge/cache/compiled_kernels/hj/chjiuqziosqwb6jsw22rbpokveo5mfuzvlza6l6lwjc2wou2wm5e.py new file mode 100644 index 0000000000000000000000000000000000000000..a48d0b653722ebce88612bd6b3fd013c58430d98 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hj/chjiuqziosqwb6jsw22rbpokveo5mfuzvlza6l6lwjc2wou2wm5e.py @@ -0,0 +1,876 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/bd/cbdcbcvihdbcyzwk7psebztx3n4mfgcbdwll5wm3ymeikhxitly7.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/6a/c6ahrgp7r4tqigkmmvnxij5lsn6k4v6kk2srdfijxf2qpmohaxmh.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem_1 +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf7] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=getitem_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=2] = call_function[target=operator.getitem](args = (%flex_attention, 1), kwargs = {}) +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%getitem_1,%mul_15 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5c/c5cbmaklqhtijnznwcqd24bqwgvkyjaacu7sxp5mgjhbog44ug34.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:2" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:2" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, 8, 32, 1, stream=stream2) + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + buf11 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, buf11, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream2) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream2 = get_raw_stream(2) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream2) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf11, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf9, buf10, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 96 + primals_2 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = 96 + primals_4 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_5 = 96 + primals_6 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_8 = 96 + primals_9 = 96 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hk/ab36029466256c43ea5c438a6a68ba71f374dc7741bbdf37f5fb28fea662f978.best_config b/progress/SpecForge/cache/compiled_kernels/hk/ab36029466256c43ea5c438a6a68ba71f374dc7741bbdf37f5fb28fea662f978.best_config new file mode 100644 index 0000000000000000000000000000000000000000..26d4b338538085b9ff438ba22ca45b3a0321909a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hk/ab36029466256c43ea5c438a6a68ba71f374dc7741bbdf37f5fb28fea662f978.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 39, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hk/chkgaqygcy26eau7yvwmovujnogazcg54xtra6atvqxpqdityswt.py b/progress/SpecForge/cache/compiled_kernels/hk/chkgaqygcy26eau7yvwmovujnogazcg54xtra6atvqxpqdityswt.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfa458f095003e1ad4f8616f14e3f82b284217d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hk/chkgaqygcy26eau7yvwmovujnogazcg54xtra6atvqxpqdityswt.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/hk/chkuge5g5hi37tavjjstol4ku63l7tu6d6mun2mox5bqbwp5hezo.py b/progress/SpecForge/cache/compiled_kernels/hk/chkuge5g5hi37tavjjstol4ku63l7tu6d6mun2mox5bqbwp5hezo.py new file mode 100644 index 0000000000000000000000000000000000000000..137bc3ec2e91965995f8e9821a89f2d7c7238d54 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hk/chkuge5g5hi37tavjjstol4ku63l7tu6d6mun2mox5bqbwp5hezo.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 10 + stride_q_idx_h = 100 + stride_q_idx_n = 10 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/hm/chmowq2puma6v6zqlzr6k2lj7vqksuohhp5jmiaqyxx56jyjg54z.py b/progress/SpecForge/cache/compiled_kernels/hm/chmowq2puma6v6zqlzr6k2lj7vqksuohhp5jmiaqyxx56jyjg54z.py new file mode 100644 index 0000000000000000000000000000000000000000..2864c6d06b8f52daccbe5127134048d0b93a0192 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hm/chmowq2puma6v6zqlzr6k2lj7vqksuohhp5jmiaqyxx56jyjg54z.py @@ -0,0 +1,879 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/kb/ckbjd7tia4titdhs6aqqlxudupbmbujgbumxjcqd3bsm37ed3fra.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5c/c5cxskvlmc4hq63kmkcexukwuhivuede7r7tzidpzo5llwdedeem.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5j/c5j7yk5hlaaxs42qwjlmoczwtoukaw2dio2o6p7qfekdy5upikyv.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:2" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:2" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream2) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream2) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream2 = get_raw_stream(2) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream2) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 96 + arg1_1 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = 96 + arg3_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg4_1 = 96 + arg5_1 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg7_1 = 96 + arg8_1 = 96 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hn/chnb6aeoibwfdajl2bc2brixhsgsiehiaisuunjcavjkmuxf45kd.py b/progress/SpecForge/cache/compiled_kernels/hn/chnb6aeoibwfdajl2bc2brixhsgsiehiaisuunjcavjkmuxf45kd.py new file mode 100644 index 0000000000000000000000000000000000000000..e1621f99a4b37d963926099b4c57e1b9b756e259 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hn/chnb6aeoibwfdajl2bc2brixhsgsiehiaisuunjcavjkmuxf45kd.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hn/chnf6gmeqjkaaeq366fy6labeawktwqqaqbrrgbzlm3bgplzdift.py b/progress/SpecForge/cache/compiled_kernels/hn/chnf6gmeqjkaaeq366fy6labeawktwqqaqbrrgbzlm3bgplzdift.py new file mode 100644 index 0000000000000000000000000000000000000000..25329c69226a65127f15db6ffbc479fb7ca47168 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hn/chnf6gmeqjkaaeq366fy6labeawktwqqaqbrrgbzlm3bgplzdift.py @@ -0,0 +1,707 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yu/cyupab57anmfwvft5q26pgkrbaqaizg55x4klblniianuuv4yuui.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3h/c3h3fb5vqykgr7s3powfrnsc5alooplbijdgjizqo3xq5psavrvz.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream4) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream4) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hn/chnrmksvanmmfsxlkjauwpdumr44nss3o34fnpwaoh3k66btggmy.py b/progress/SpecForge/cache/compiled_kernels/hn/chnrmksvanmmfsxlkjauwpdumr44nss3o34fnpwaoh3k66btggmy.py new file mode 100644 index 0000000000000000000000000000000000000000..e7632f13fc634d7c9bf48d289baebb15ac82942d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hn/chnrmksvanmmfsxlkjauwpdumr44nss3o34fnpwaoh3k66btggmy.py @@ -0,0 +1,1018 @@ +# AOT ID: ['0_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/r7/cr7yqbjyejqfa2ld73pqqr2tp7g6ybrhs4u34xt5s4cusntoze4a.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 768, 128][3145728, 98304, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (768, 768, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 294912, 'r0_': 12582912}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 24576 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 768) + x1 = xindex // 768 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xr/cxr7lbwuek4pgt6vjfn4e5p45m7ggupaspkah63tfuruz7jsob2z.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 768, 128][3145728, 98304, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_4 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_5] +# %primals_8 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (768, 768, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 786432, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3145728, 98304, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3145728, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 98304*off_hkv + 786432*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_5, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_6, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_7, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_8, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_9, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_10, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_11, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(getitem, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 768), (24576, 768, 1)) + assert_size_stride(tangents_1, (1, 32, 768, 128), (3145728, 98304, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 768), (24576, 768, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 24576, 128, stream=stream1) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 768, 128), (786432, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 768, 128), (786432, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_4, primals_5, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 30, 1, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 768), (24576, 768, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 768, 128), (3145728, 98304, 128, 1), device='cuda:1', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 768), (24576, 768, 1), device='cuda:1', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hp/chpbrht2qdbgtwkzmecvfpzvh27f3jovwjt2iytm6cm55dnqd2th.py b/progress/SpecForge/cache/compiled_kernels/hp/chpbrht2qdbgtwkzmecvfpzvh27f3jovwjt2iytm6cm55dnqd2th.py new file mode 100644 index 0000000000000000000000000000000000000000..f3dbe9ce406590318917306c047141c9c22aa1d8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hp/chpbrht2qdbgtwkzmecvfpzvh27f3jovwjt2iytm6cm55dnqd2th.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hp/chpdkf2fjutsit7encshealznhlojx4nrwb5zf3teglvh7rxiij4.py b/progress/SpecForge/cache/compiled_kernels/hp/chpdkf2fjutsit7encshealznhlojx4nrwb5zf3teglvh7rxiij4.py new file mode 100644 index 0000000000000000000000000000000000000000..35f8848f07de91330d3670aeadf283d5db24c322 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hp/chpdkf2fjutsit7encshealznhlojx4nrwb5zf3teglvh7rxiij4.py @@ -0,0 +1,1019 @@ +# AOT ID: ['3_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/rh/crhihk6v6nqpqusr4vp7qqztet42ugdhn524fpi6y32vdhe2bair.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yk/cyktniqut4hvzwm2zxp6c5v72tensuzn3sghmn25du5w3dywb7be.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream3) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 112 + primals_9 = 112 + primals_2 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 112, 128), (458752, 14336, 128, 1), device='cuda:3', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:3', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/hq/10abecf090c356ddab0951297c5ae2ad60f337c929c435b91e8aeeb53861834a.best_config b/progress/SpecForge/cache/compiled_kernels/hq/10abecf090c356ddab0951297c5ae2ad60f337c929c435b91e8aeeb53861834a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..bcc504e4487f2dbcdc7e188f583ca1e290d15ed4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hq/10abecf090c356ddab0951297c5ae2ad60f337c929c435b91e8aeeb53861834a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "2685c2d349c32243d4ee216505dfdf1e257d04d8316595ed69d4ca3499146788", "found_by_coordesc": false, "time_taken_ms": 47, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/hq/chqmxyert4bgtobyyi5cbjgrutmamoo4uk7ilsavd4ghz5g3vd2i.py b/progress/SpecForge/cache/compiled_kernels/hq/chqmxyert4bgtobyyi5cbjgrutmamoo4uk7ilsavd4ghz5g3vd2i.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b1d553ad0130643273a50db1da8e3c8d2e887c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hq/chqmxyert4bgtobyyi5cbjgrutmamoo4uk7ilsavd4ghz5g3vd2i.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/hv/chvswvsk4npg4n2lsriffd75ficguhr3esq2b2ixq3kdvrdkvbfp.py b/progress/SpecForge/cache/compiled_kernels/hv/chvswvsk4npg4n2lsriffd75ficguhr3esq2b2ixq3kdvrdkvbfp.py new file mode 100644 index 0000000000000000000000000000000000000000..8f783079da9fe49cb4363177e8a8f183f1dbe5aa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/hv/chvswvsk4npg4n2lsriffd75ficguhr3esq2b2ixq3kdvrdkvbfp.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/ik/cikoy52g64lcclynnvvpqbzomtal7nvwjxix2eg4kun2c477zpyj.py b/progress/SpecForge/cache/compiled_kernels/ik/cikoy52g64lcclynnvvpqbzomtal7nvwjxix2eg4kun2c477zpyj.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bda2111865e3598c4e0b255f9abcfde5ab8b8b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ik/cikoy52g64lcclynnvvpqbzomtal7nvwjxix2eg4kun2c477zpyj.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ik/ciku2i4i7g7lrh4vv5jtgrwm3j5cuwsm2kbkn7qxbvps74cgtrvs.py b/progress/SpecForge/cache/compiled_kernels/ik/ciku2i4i7g7lrh4vv5jtgrwm3j5cuwsm2kbkn7qxbvps74cgtrvs.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f719d68101fb6f3c4e29da5435e017d835770 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ik/ciku2i4i7g7lrh4vv5jtgrwm3j5cuwsm2kbkn7qxbvps74cgtrvs.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 11 + stride_q_idx_h = 121 + stride_q_idx_n = 11 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iv/3a5d9e222f57efce846ee58aa947d14bff8cd0751cfcdf772242cdabfdd3be51.best_config b/progress/SpecForge/cache/compiled_kernels/iv/3a5d9e222f57efce846ee58aa947d14bff8cd0751cfcdf772242cdabfdd3be51.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iv/3a5d9e222f57efce846ee58aa947d14bff8cd0751cfcdf772242cdabfdd3be51.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iv/civfnuj45ifcjf545fyz4ryg2q422dh2larygr2jr5yyli7gz6ae.py b/progress/SpecForge/cache/compiled_kernels/iv/civfnuj45ifcjf545fyz4ryg2q422dh2larygr2jr5yyli7gz6ae.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f313c77862470d14cef6e7c8d4d764384ab9a3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iv/civfnuj45ifcjf545fyz4ryg2q422dh2larygr2jr5yyli7gz6ae.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/iv/civlbyztu7drslvmp2ko4sbzm3kuineap6ozcy5yx6vlssubqtcx.py b/progress/SpecForge/cache/compiled_kernels/iv/civlbyztu7drslvmp2ko4sbzm3kuineap6ozcy5yx6vlssubqtcx.py new file mode 100644 index 0000000000000000000000000000000000000000..29c7fb6f008783632680d154437d8be20c7eda2c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iv/civlbyztu7drslvmp2ko4sbzm3kuineap6ozcy5yx6vlssubqtcx.py @@ -0,0 +1,1019 @@ +# AOT ID: ['3_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/rh/crhomk4dbfrkitfzvo6fho5x4nowgfrwmpv4nnpqbiigzvcf6dm5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5x/c5xdgqojur4z4zvrtv44kgcalsk3yhzcmhqhexku5vmlehrc6gsj.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream4) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 128 + primals_9 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:4', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:4', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/iv/civspn3mu5mvev3iue5ugsniof2fkp4gxubekpwhysago5i7xsc7.py b/progress/SpecForge/cache/compiled_kernels/iv/civspn3mu5mvev3iue5ugsniof2fkp4gxubekpwhysago5i7xsc7.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c9188a8ad21c15b49d3b172880e108df75a4ce --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iv/civspn3mu5mvev3iue5ugsniof2fkp4gxubekpwhysago5i7xsc7.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/iv/e68226a606cff553afe859fb6dddd4d957b55570d496d324b61da5d0e71a4f51.best_config b/progress/SpecForge/cache/compiled_kernels/iv/e68226a606cff553afe859fb6dddd4d957b55570d496d324b61da5d0e71a4f51.best_config new file mode 100644 index 0000000000000000000000000000000000000000..9e79ec4742dad2efed7e7dfbfa0cc764d0764f83 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iv/e68226a606cff553afe859fb6dddd4d957b55570d496d324b61da5d0e71a4f51.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 100, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iy/b6308ded3063494da0162100390f18771273d1399bce39b871d07c3da2bfae06.best_config b/progress/SpecForge/cache/compiled_kernels/iy/b6308ded3063494da0162100390f18771273d1399bce39b871d07c3da2bfae06.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iy/b6308ded3063494da0162100390f18771273d1399bce39b871d07c3da2bfae06.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iy/ciy2s4y23ebwguemfa5mzklgtf7ftiru2o7wglybbpeorpr3fzwa.py b/progress/SpecForge/cache/compiled_kernels/iy/ciy2s4y23ebwguemfa5mzklgtf7ftiru2o7wglybbpeorpr3fzwa.py new file mode 100644 index 0000000000000000000000000000000000000000..a6deab9d6d13dec4b31280ff0e396c34cccd5f28 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iy/ciy2s4y23ebwguemfa5mzklgtf7ftiru2o7wglybbpeorpr3fzwa.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iy/ciycjvyaco6b3mtyliulibjjjnk2hxipagfayteqs2axxcgg4f2h.py b/progress/SpecForge/cache/compiled_kernels/iy/ciycjvyaco6b3mtyliulibjjjnk2hxipagfayteqs2axxcgg4f2h.py new file mode 100644 index 0000000000000000000000000000000000000000..1411aa2e28eaee7b3645cfbb51a53e85ff118f42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iy/ciycjvyaco6b3mtyliulibjjjnk2hxipagfayteqs2axxcgg4f2h.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/iy/ciyksx73miv4q3buobp6lpr65funn6ikpjyesnfdvjt7n2v4tfmd.py b/progress/SpecForge/cache/compiled_kernels/iy/ciyksx73miv4q3buobp6lpr65funn6ikpjyesnfdvjt7n2v4tfmd.py new file mode 100644 index 0000000000000000000000000000000000000000..7831609abb6217b82022a9de4aacf379d95a0ab1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iy/ciyksx73miv4q3buobp6lpr65funn6ikpjyesnfdvjt7n2v4tfmd.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/iy/ciyzuhstlzrjjhn4kbqqhngdjkri4exl5in452hyq5ctfcf4rex7.py b/progress/SpecForge/cache/compiled_kernels/iy/ciyzuhstlzrjjhn4kbqqhngdjkri4exl5in452hyq5ctfcf4rex7.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb421bd90e7609e76609ad6995f0dbd9a703a41 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/iy/ciyzuhstlzrjjhn4kbqqhngdjkri4exl5in452hyq5ctfcf4rex7.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/vi/cviuncbgf6gw2ilej3jwbvnvnhxltsoxsxysnqbw2r6nrbtpxkk5.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 7][7, 7, 1]cuda:1" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 7, 7][49, 49, 7, 1]cuda:1" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/kc/ckcrxe3eiivcle4pzzfnoxflbgl5wp2ssdfqd2wwvgeyvlpbtiwx.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_15, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(primals_16, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_17, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(primals_18, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_19, (1, 1, 7, 7), (49, 49, 7, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream1) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 896 + primals_2 = rand_strided((1, 32, 896, 128), (3670016, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = 896 + primals_4 = rand_strided((1, 8, 896, 128), (917504, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = 896 + primals_6 = rand_strided((1, 8, 896, 128), (917504, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = 7 + primals_8 = 7 + primals_9 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_10 = 896 + primals_11 = 896 + primals_12 = 7 + primals_13 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/jg/cjgpvyyj2lopal7cmzbvlsv6ri73aolgwrniohxegoderaxqmmzk.py b/progress/SpecForge/cache/compiled_kernels/jg/cjgpvyyj2lopal7cmzbvlsv6ri73aolgwrniohxegoderaxqmmzk.py new file mode 100644 index 0000000000000000000000000000000000000000..be7eceddcbcc7114ef35da42110bfa980e191cc0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jg/cjgpvyyj2lopal7cmzbvlsv6ri73aolgwrniohxegoderaxqmmzk.py @@ -0,0 +1,876 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/a7/ca7s7tp233627gzszjumpewmgfr3x27cno6rvah7rvjcudtqof37.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xe/cxew5mxpxjz73qoybnn4o7yk43aewkli6b766r2424hpsq7cuyyb.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem_1 +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=2] = call_function[target=operator.getitem](args = (%flex_attention, 1), kwargs = {}) +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%getitem_1,%mul_15 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/44/c44m5klhlzg7nfvzfelnbb3hjh2jwzh2e5yyk3vtcvhyw6rbnjo6.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:3" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, 8, 32, 1, stream=stream3) + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + buf11 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, buf11, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream3) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream3 = get_raw_stream(3) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream3) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf11, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf9, buf10, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 112 + primals_2 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = 112 + primals_4 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 112 + primals_6 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_8 = 112 + primals_9 = 112 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/jo/a0ff330db42264a3913ad319b2d45bc3bff38888eef6930a4f4bd7f628329b8a.best_config b/progress/SpecForge/cache/compiled_kernels/jo/a0ff330db42264a3913ad319b2d45bc3bff38888eef6930a4f4bd7f628329b8a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..697e135657e8b1cdff0bd1da718d803688ace9c1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jo/a0ff330db42264a3913ad319b2d45bc3bff38888eef6930a4f4bd7f628329b8a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "RHE7JXFOJQBB2ESJBNT5OZ3PCTRJ7WSJRB7A2GLRM73N3EI7TWDQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/jo/cjoxenwm7p2jwy3xzklvpkqkmrskyjg666yf62k6ubx4klm5uzrc.py b/progress/SpecForge/cache/compiled_kernels/jo/cjoxenwm7p2jwy3xzklvpkqkmrskyjg666yf62k6ubx4klm5uzrc.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2e19eb011a27e13231ea037c3d690e7c590caa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jo/cjoxenwm7p2jwy3xzklvpkqkmrskyjg666yf62k6ubx4klm5uzrc.py @@ -0,0 +1,58 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/jp/7437f3243ef1be7903cf1e93987418b74bd14309e13c65afd8d4e733ee19fb17.best_config b/progress/SpecForge/cache/compiled_kernels/jp/7437f3243ef1be7903cf1e93987418b74bd14309e13c65afd8d4e733ee19fb17.best_config new file mode 100644 index 0000000000000000000000000000000000000000..95c33e2998cf8060ea61b42b0c24eeca276991b2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jp/7437f3243ef1be7903cf1e93987418b74bd14309e13c65afd8d4e733ee19fb17.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "EZ3RKHM23IARJ6OLUDAWBKS54ORODGZICEVX6ZI5AEYV54IQCCLA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/jp/cjpja3r5mbp5nuheqzslcn5gahnsivxrb3wegvw72ojoc6iuj7se.py b/progress/SpecForge/cache/compiled_kernels/jp/cjpja3r5mbp5nuheqzslcn5gahnsivxrb3wegvw72ojoc6iuj7se.py new file mode 100644 index 0000000000000000000000000000000000000000..390f7ea5046a66dcca392e297ffbe69c86c74954 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jp/cjpja3r5mbp5nuheqzslcn5gahnsivxrb3wegvw72ojoc6iuj7se.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 483328}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 60416 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py b/progress/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py new file mode 100644 index 0000000000000000000000000000000000000000..0d90c346667c21a4445c95c4a55d04e6f5938449 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jt/cjtngjzio5oudkq4n4xggwz5enmgujrff3ktfnon7oykgb7as5tu.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/jt/d59aa7257dd95e4a1eaac332210139848ef76db0c409b18376ff56e52adfaf28.best_config b/progress/SpecForge/cache/compiled_kernels/jt/d59aa7257dd95e4a1eaac332210139848ef76db0c409b18376ff56e52adfaf28.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7b46dd53a7c4823eff09151de7fa0b109423744a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jt/d59aa7257dd95e4a1eaac332210139848ef76db0c409b18376ff56e52adfaf28.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 88, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ju/cjuz7wfidlcrsineazvmnsxdrbhkm5youjwj2zn5sa3uu7f7lue3.py b/progress/SpecForge/cache/compiled_kernels/ju/cjuz7wfidlcrsineazvmnsxdrbhkm5youjwj2zn5sa3uu7f7lue3.py new file mode 100644 index 0000000000000000000000000000000000000000..6a3a53c4a60fe177cf3a7577cf8fb9e5b2150496 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ju/cjuz7wfidlcrsineazvmnsxdrbhkm5youjwj2zn5sa3uu7f7lue3.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/jw/cjwtqvxssq5bhbgq4dcc3a5swgw24a7smwlcuuhvmz7ctbt4rdza.py b/progress/SpecForge/cache/compiled_kernels/jw/cjwtqvxssq5bhbgq4dcc3a5swgw24a7smwlcuuhvmz7ctbt4rdza.py new file mode 100644 index 0000000000000000000000000000000000000000..38eb9a6488878f1cdd57ac6f40acedde6e9a90dd --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/jw/cjwtqvxssq5bhbgq4dcc3a5swgw24a7smwlcuuhvmz7ctbt4rdza.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/k5/ck5i3en3gg5fzdujwernjeao3eip2ekakxzrqxixsj7ffy66ed7g.py b/progress/SpecForge/cache/compiled_kernels/k5/ck5i3en3gg5fzdujwernjeao3eip2ekakxzrqxixsj7ffy66ed7g.py new file mode 100644 index 0000000000000000000000000000000000000000..6510af7bb74679b638541d859a4020589333f467 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/k5/ck5i3en3gg5fzdujwernjeao3eip2ekakxzrqxixsj7ffy66ed7g.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1753088, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 7012352, 219136, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 7012352, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 14 + stride_q_idx_h = 196 + stride_q_idx_n = 14 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 219136*off_hkv + 1753088*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1712 + KV_LEN = 1712 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1712 + KV_LEN = 1712 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/k6/ck6dpchfl3ysnaitjfehjf3bflbosiasi5wgge2nsqa5q7yhzvr2.py b/progress/SpecForge/cache/compiled_kernels/k6/ck6dpchfl3ysnaitjfehjf3bflbosiasi5wgge2nsqa5q7yhzvr2.py new file mode 100644 index 0000000000000000000000000000000000000000..8c745e0cec3ef058e2628cb368deece64f235616 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/k6/ck6dpchfl3ysnaitjfehjf3bflbosiasi5wgge2nsqa5q7yhzvr2.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/k7/ck7pjrcyc4oktjtivu4jbbkw4unreb6r66sdeh4xglza5tlv75vo.py b/progress/SpecForge/cache/compiled_kernels/k7/ck7pjrcyc4oktjtivu4jbbkw4unreb6r66sdeh4xglza5tlv75vo.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbb62a9ed31da21bb62c34c75c8e9aa87049eff --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/k7/ck7pjrcyc4oktjtivu4jbbkw4unreb6r66sdeh4xglza5tlv75vo.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 5570560, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1392640, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1392640, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1360 + ZKV = 1 + KV_LEN = 1360 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 11 + stride_kv_idx_h = 121 + stride_kv_idx_m = 11 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 174080*idx_hq + 5570560*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/kb/ckbjd7tia4titdhs6aqqlxudupbmbujgbumxjcqd3bsm37ed3fra.py b/progress/SpecForge/cache/compiled_kernels/kb/ckbjd7tia4titdhs6aqqlxudupbmbujgbumxjcqd3bsm37ed3fra.py new file mode 100644 index 0000000000000000000000000000000000000000..326a0f049fac15bf6a2b8e85f0ad7d83938c0307 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kb/ckbjd7tia4titdhs6aqqlxudupbmbujgbumxjcqd3bsm37ed3fra.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/kd/ckdahpjmmxvwpaza3ropblwkthkgakfpivw5wj2y3tycy34bggtf.py b/progress/SpecForge/cache/compiled_kernels/kd/ckdahpjmmxvwpaza3ropblwkthkgakfpivw5wj2y3tycy34bggtf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5e1c0289391f36d16856f4fe1dbbe18bf1b75f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kd/ckdahpjmmxvwpaza3ropblwkthkgakfpivw5wj2y3tycy34bggtf.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ka/cka7xx5r5hjbxmpyy422mg5htwwxacunbdxsuwsvx5i3hsleezxl.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nz/cnz3o2q2vb2ti2gyyevhvaopio72peoav2vt76s5rmoxvuy4zs6c.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:2" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:2" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:2" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:2" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:2" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream2) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1600 + primals_11 = 1600 + primals_7 = 13 + primals_8 = 13 + primals_12 = 13 + primals_14 = 13 + primals_16 = 13 + primals_17 = 13 + primals_19 = 13 + primals_22 = 13 + primals_21 = 13 + primals_24 = 13 + primals_27 = 13 + primals_26 = 13 + primals_2 = rand_strided((1, 32, 1600, 128), (6553600, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 1600, 128), (1638400, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 1600, 128), (1638400, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((1, 32, 1600, 128), (6553600, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1600), (51200, 1600, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1600, 128), (6553600, 204800, 128, 1), device='cuda:2', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1600), (51200, 1600, 1), device='cuda:2', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/kg/ckgf6ud7veeoti7gvwpatfrcoqzzyfqtpcvitdqjb4nh3bl4xkga.py b/progress/SpecForge/cache/compiled_kernels/kg/ckgf6ud7veeoti7gvwpatfrcoqzzyfqtpcvitdqjb4nh3bl4xkga.py new file mode 100644 index 0000000000000000000000000000000000000000..5948a7b465ead164fe1d89f277f3836311e2e481 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kg/ckgf6ud7veeoti7gvwpatfrcoqzzyfqtpcvitdqjb4nh3bl4xkga.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sc/csc67q546j7muhvoccled2oa4u3qywv2ugybi3n244etkkzozhm2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1104, 128][4521984, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1104, 1104, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4521984, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1130496, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1130496, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1104 + ZKV = 1 + KV_LEN = 1104 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 141312*idx_hq + 4521984*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zr/czr6mneyujqwtcnujzpzrifkj4z36xtvmjyymecmrkgwf7765fr3.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 282624}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 35328 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1104, 128), (4521984, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_5, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_6, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_7, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_8, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_9, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_10, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_11, (1, 1, 9, 9), (81, 81, 9, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, 1104), (35328, 1104, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1104), (35328, 1104, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1104, 128), (4521984, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 9, 1, 32, stream=stream3) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_poi_fused_mul_1.run(buf0, buf5, 35328, stream=stream3) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/kw/ckwbl7bxz2u7dcyoraiz5hhb5lewounuhgdbac4eqroe5ldc7au4.py b/progress/SpecForge/cache/compiled_kernels/kw/ckwbl7bxz2u7dcyoraiz5hhb5lewounuhgdbac4eqroe5ldc7au4.py new file mode 100644 index 0000000000000000000000000000000000000000..08347e131cb11f10fdc7a91d5caca57a5515aeee --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kw/ckwbl7bxz2u7dcyoraiz5hhb5lewounuhgdbac4eqroe5ldc7au4.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7471104, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1867776, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1867776, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1824 + ZKV = 1 + KV_LEN = 1824 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 233472*idx_hq + 7471104*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/kw/ckwwca6uq5lntof6nzoke2sd3e4luvsf66256uzocg6dxanepczq.py b/progress/SpecForge/cache/compiled_kernels/kw/ckwwca6uq5lntof6nzoke2sd3e4luvsf66256uzocg6dxanepczq.py new file mode 100644 index 0000000000000000000000000000000000000000..1a35134a97c5e773bfca83aad63a27ed35778a73 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kw/ckwwca6uq5lntof6nzoke2sd3e4luvsf66256uzocg6dxanepczq.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 219136*idx_hq + 7012352*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/kz/ckzlcn5lsemwkkbb4yz5zuni7nbpdavabua7uvpppxufqmezcy3s.py b/progress/SpecForge/cache/compiled_kernels/kz/ckzlcn5lsemwkkbb4yz5zuni7nbpdavabua7uvpppxufqmezcy3s.py new file mode 100644 index 0000000000000000000000000000000000000000..edd8f489b404e43312d4975a293d0c51aef10910 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/kz/ckzlcn5lsemwkkbb4yz5zuni7nbpdavabua7uvpppxufqmezcy3s.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/l3/cl36lruem3eh37k6mle2ayydiuxwzpeokqgdl2gwgf7j4pa7gule.py b/progress/SpecForge/cache/compiled_kernels/l3/cl36lruem3eh37k6mle2ayydiuxwzpeokqgdl2gwgf7j4pa7gule.py new file mode 100644 index 0000000000000000000000000000000000000000..7f34ceab580dda8d643198f1146a8d14e720cf5e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l3/cl36lruem3eh37k6mle2ayydiuxwzpeokqgdl2gwgf7j4pa7gule.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/l4/cl43mafvis4joj6hsgpkw4ytnx3pmzzno7y553vh2qnyn37z3gsp.py b/progress/SpecForge/cache/compiled_kernels/l4/cl43mafvis4joj6hsgpkw4ytnx3pmzzno7y553vh2qnyn37z3gsp.py new file mode 100644 index 0000000000000000000000000000000000000000..2c09c57096fb945130a5f0161774ca3cd8239856 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l4/cl43mafvis4joj6hsgpkw4ytnx3pmzzno7y553vh2qnyn37z3gsp.py @@ -0,0 +1,1019 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zl/czl55bstkzaiyo63lnwmmvyroovbivby4d7fixeupckdyextlwhk.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3z/c3zi2pt6zmbthc6ythgt5p4ednhp6m24gpscb2pt6adf6xojetua.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream5) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 112 + primals_9 = 112 + primals_2 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 112, 128), (458752, 14336, 128, 1), device='cuda:5', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:5', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/l6/87089aa230f9f7f89c2427c3012d9a3a71509e68c85ed05f6e20175d4c1c58a3.best_config b/progress/SpecForge/cache/compiled_kernels/l6/87089aa230f9f7f89c2427c3012d9a3a71509e68c85ed05f6e20175d4c1c58a3.best_config new file mode 100644 index 0000000000000000000000000000000000000000..e938183783c79c7db17d698edd39f7fdd28ad991 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l6/87089aa230f9f7f89c2427c3012d9a3a71509e68c85ed05f6e20175d4c1c58a3.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 76, "triton_cache_hash": "FR47NT5CNCUU6ARP3WHN7ECD37NFCS6GJTDCEIBLTMP23LKCYSWA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/l6/cl6xhyxtqtnczpdlmb3w5oimiq4gvtasqjswbgz4yakhnh4qkps7.py b/progress/SpecForge/cache/compiled_kernels/l6/cl6xhyxtqtnczpdlmb3w5oimiq4gvtasqjswbgz4yakhnh4qkps7.py new file mode 100644 index 0000000000000000000000000000000000000000..16c929417c4ff92be4de8e46288c4e4ef56bb010 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l6/cl6xhyxtqtnczpdlmb3w5oimiq4gvtasqjswbgz4yakhnh4qkps7.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 233472}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 29184 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/l7/186261f5a20b258de55632c9b06206aed865aba375edf8ad8f246f8c18e53e3d.best_config b/progress/SpecForge/cache/compiled_kernels/l7/186261f5a20b258de55632c9b06206aed865aba375edf8ad8f246f8c18e53e3d.best_config new file mode 100644 index 0000000000000000000000000000000000000000..1ac757b0095439292d408c3745fae2886e918dff --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l7/186261f5a20b258de55632c9b06206aed865aba375edf8ad8f246f8c18e53e3d.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py b/progress/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py new file mode 100644 index 0000000000000000000000000000000000000000..9eff8e905ee403a47073dda6eb914b964f4f3669 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py b/progress/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py new file mode 100644 index 0000000000000000000000000000000000000000..3d062f922d963111a53bc6c58c988ff2b849c81e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l7/cl76p6rje3cyrrbyvxjjj7oxbieltfs4p5xqjre35l6wnofhynby.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/l7/cl7mi2m6hh35ktke2amgxxvneopsr4tc635ultdyz6bpai22e6dc.py b/progress/SpecForge/cache/compiled_kernels/l7/cl7mi2m6hh35ktke2amgxxvneopsr4tc635ultdyz6bpai22e6dc.py new file mode 100644 index 0000000000000000000000000000000000000000..e7349ef335c208f67746fa8dc125a24a56eba622 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l7/cl7mi2m6hh35ktke2amgxxvneopsr4tc635ultdyz6bpai22e6dc.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2228224, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 557056, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 557056, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 544 + ZKV = 1 + KV_LEN = 544 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 69632*idx_hq + 2228224*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/l7/cl7z52obzeb6b6wt3rif2ui2hg4y7au6i6bu65dywj7kjitp2m2h.py b/progress/SpecForge/cache/compiled_kernels/l7/cl7z52obzeb6b6wt3rif2ui2hg4y7au6i6bu65dywj7kjitp2m2h.py new file mode 100644 index 0000000000000000000000000000000000000000..70aa57c3e4c19a6f21014c0e38734b2a4d9508be --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/l7/cl7z52obzeb6b6wt3rif2ui2hg4y7au6i6bu65dywj7kjitp2m2h.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/ld/cldwkxdv7izigkhcjor2j2snktzj2wsktvhftnyigkxfdpjoqggz.py b/progress/SpecForge/cache/compiled_kernels/ld/cldwkxdv7izigkhcjor2j2snktzj2wsktvhftnyigkxfdpjoqggz.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b6823a89f24a39fabfc1e8b0aa00ba13b1ff56 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ld/cldwkxdv7izigkhcjor2j2snktzj2wsktvhftnyigkxfdpjoqggz.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 5570560, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1392640, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1392640, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1360 + ZKV = 1 + KV_LEN = 1360 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 11 + stride_kv_idx_h = 121 + stride_kv_idx_m = 11 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 174080*idx_hq + 5570560*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/li/clinlrn6ucguu7skp4i3gj5ozytngsgi3lwbmsbk3e5v6mhi56ld.py b/progress/SpecForge/cache/compiled_kernels/li/clinlrn6ucguu7skp4i3gj5ozytngsgi3lwbmsbk3e5v6mhi56ld.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d8dff0e099a231b7533afff2f5c4ba6be1ccc0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/li/clinlrn6ucguu7skp4i3gj5ozytngsgi3lwbmsbk3e5v6mhi56ld.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/li/cliwuwghfjyoz2lajqh2nrqenj4dkn6hghy7qyrzmt44ucfqu5rv.py b/progress/SpecForge/cache/compiled_kernels/li/cliwuwghfjyoz2lajqh2nrqenj4dkn6hghy7qyrzmt44ucfqu5rv.py new file mode 100644 index 0000000000000000000000000000000000000000..095c552294907d355d8dfeec11059faf4f3303a7 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/li/cliwuwghfjyoz2lajqh2nrqenj4dkn6hghy7qyrzmt44ucfqu5rv.py @@ -0,0 +1,715 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ck/cckbx73nknzvdgfzjsnagur7f5trpu2rm7axhe3m7etjboxosqbr.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:5" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:5" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/w4/cw4fsmrzlz3yxap27ofc3ad2deoo2c6ok5s5m63cejtspevv4x7z.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream5) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream5) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:5', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/lk/clk4vu5ikxjkumv7kyyjbhc7rinhokcxh5djfto6vndhpocrouas.py b/progress/SpecForge/cache/compiled_kernels/lk/clk4vu5ikxjkumv7kyyjbhc7rinhokcxh5djfto6vndhpocrouas.py new file mode 100644 index 0000000000000000000000000000000000000000..5902f984fe8cb1b722870e7e8ea50b58fb27e86d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/lk/clk4vu5ikxjkumv7kyyjbhc7rinhokcxh5djfto6vndhpocrouas.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1523712, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 6094848, 190464, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 6094848, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 12 + stride_q_idx_h = 144 + stride_q_idx_n = 12 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 190464*off_hkv + 1523712*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1488 + KV_LEN = 1488 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1488 + KV_LEN = 1488 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/lr/clrmkkwg5hovgsk35oip3w4ae2haqifxlbu53ijlrjgd4u6j5cve.py b/progress/SpecForge/cache/compiled_kernels/lr/clrmkkwg5hovgsk35oip3w4ae2haqifxlbu53ijlrjgd4u6j5cve.py new file mode 100644 index 0000000000000000000000000000000000000000..3f64ed743eb8c9f95bc65586e370ead23b8b9938 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/lr/clrmkkwg5hovgsk35oip3w4ae2haqifxlbu53ijlrjgd4u6j5cve.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/lr/f72a99f7662315c5b657245797cd81be4cf7fc45f97f84d52fa93f6672d8922c.best_config b/progress/SpecForge/cache/compiled_kernels/lr/f72a99f7662315c5b657245797cd81be4cf7fc45f97f84d52fa93f6672d8922c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7e00b0ff230fa84cffc0ddb776fb2d485384a5d7 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/lr/f72a99f7662315c5b657245797cd81be4cf7fc45f97f84d52fa93f6672d8922c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ls/cls5ocxpu5yuevimgi3zbr2b6p42lr4ardundcj2ouobmd7zu7s7.py b/progress/SpecForge/cache/compiled_kernels/ls/cls5ocxpu5yuevimgi3zbr2b6p42lr4ardundcj2ouobmd7zu7s7.py new file mode 100644 index 0000000000000000000000000000000000000000..082bad985348938b604f97586a94c5aae0a87c2f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ls/cls5ocxpu5yuevimgi3zbr2b6p42lr4ardundcj2ouobmd7zu7s7.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/lw/clwi4dqbrij6ek7edfo5fy635sclo6xa5a5yx4tuirr5mullbmva.py b/progress/SpecForge/cache/compiled_kernels/lw/clwi4dqbrij6ek7edfo5fy635sclo6xa5a5yx4tuirr5mullbmva.py new file mode 100644 index 0000000000000000000000000000000000000000..51df7c68ecadc758294f827bccb0ccbec34d4f95 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/lw/clwi4dqbrij6ek7edfo5fy635sclo6xa5a5yx4tuirr5mullbmva.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ly/clywhyti5tzyelmxzqoxv3zabc3epa6h23oypsztralfl4aatg5q.py b/progress/SpecForge/cache/compiled_kernels/ly/clywhyti5tzyelmxzqoxv3zabc3epa6h23oypsztralfl4aatg5q.py new file mode 100644 index 0000000000000000000000000000000000000000..91deaf1efc56eac5eeeca6624f2ccd8f683e798c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ly/clywhyti5tzyelmxzqoxv3zabc3epa6h23oypsztralfl4aatg5q.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sm/csmz2k3edklza74y46mapw3rjeds7hufvjfwpaldv36ivikgh5kl.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:5" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:5" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/un/cunvnnzocm6jktkvxmstfwiwg5c25oosgsdlibhsk55izrv74lgj.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16384}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream5) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream5) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 496 + primals_2 = rand_strided((1, 32, 496, 128), (2031616, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = 496 + primals_4 = rand_strided((1, 8, 496, 128), (507904, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 496 + primals_6 = rand_strided((1, 8, 496, 128), (507904, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = 4 + primals_8 = 4 + primals_9 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_10 = 496 + primals_11 = 496 + primals_12 = 4 + primals_13 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_14 = 4 + primals_15 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_16 = 4 + primals_17 = 4 + primals_18 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_19 = 4 + primals_20 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_21 = 4 + primals_22 = 4 + primals_23 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_24 = 4 + primals_25 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_26 = 4 + primals_27 = 4 + primals_28 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/m4/2f7479a4fe52807762be6fbd208e0be18280aae0c5600ce7f5049803f99f432e.best_config b/progress/SpecForge/cache/compiled_kernels/m4/2f7479a4fe52807762be6fbd208e0be18280aae0c5600ce7f5049803f99f432e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..21da4474563d7d614bfb841e87c69a5bbfb4251e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/m4/2f7479a4fe52807762be6fbd208e0be18280aae0c5600ce7f5049803f99f432e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "O7WDT227PLBDA4RTCWQHRD6TX327YREMCF75BHIDX3W5YOJNWYSQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/m4/cm4u4rta6keduxwyvievxqnr7ndyhk7d4v4xus57imdoqu6nxwsa.py b/progress/SpecForge/cache/compiled_kernels/m4/cm4u4rta6keduxwyvievxqnr7ndyhk7d4v4xus57imdoqu6nxwsa.py new file mode 100644 index 0000000000000000000000000000000000000000..500f5a5a9f12eb55a9f023ab73a8b807e76258d6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/m4/cm4u4rta6keduxwyvievxqnr7ndyhk7d4v4xus57imdoqu6nxwsa.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/md/cmdeweoklzsy3ozvcu5a5bpcgl6evj274beoukzvay3xz6dnm543.py b/progress/SpecForge/cache/compiled_kernels/md/cmdeweoklzsy3ozvcu5a5bpcgl6evj274beoukzvay3xz6dnm543.py new file mode 100644 index 0000000000000000000000000000000000000000..8f276a06da0c594c0c884c8d5d707544949d0dc9 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/md/cmdeweoklzsy3ozvcu5a5bpcgl6evj274beoukzvay3xz6dnm543.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/n5/a3c76fb8de220d615f5d52f432447708741a1c8571d93d3d67b229f48f9d91f9.best_config b/progress/SpecForge/cache/compiled_kernels/n5/a3c76fb8de220d615f5d52f432447708741a1c8571d93d3d67b229f48f9d91f9.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/n5/a3c76fb8de220d615f5d52f432447708741a1c8571d93d3d67b229f48f9d91f9.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/n5/cn577wod4hnt5ilfoi2wqa2ani6ai3hpj6oyen4zalxgng7t6uko.py b/progress/SpecForge/cache/compiled_kernels/n5/cn577wod4hnt5ilfoi2wqa2ani6ai3hpj6oyen4zalxgng7t6uko.py new file mode 100644 index 0000000000000000000000000000000000000000..4baace210431414eec5602396486e41ed7c5d6f4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/n5/cn577wod4hnt5ilfoi2wqa2ani6ai3hpj6oyen4zalxgng7t6uko.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/n5/cn5c4s3z5bjrojucpoxdwixxjzfx6k6zimtyptnrf5ztfgbkkb27.py b/progress/SpecForge/cache/compiled_kernels/n5/cn5c4s3z5bjrojucpoxdwixxjzfx6k6zimtyptnrf5ztfgbkkb27.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7d6e1f464ce4d2f3bed45bbb62fbcf8b42fc35 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/n5/cn5c4s3z5bjrojucpoxdwixxjzfx6k6zimtyptnrf5ztfgbkkb27.py @@ -0,0 +1,739 @@ +# AOT ID: ['2_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/c5/cc5oq7jygergpvto6ee2gvnf2mra6awo63pcxqudcd7ytnj7zodx.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=arg8_1] +# %arg14_1 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=arg14_1] +# %arg17_1 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=arg17_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg14_1, %arg17_1, %arg19_1, %arg22_1, %arg24_1, %arg27_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/g4/cg4y3bgvtlp62cdorpd6bc2yltyucyl4uv5uslettv26fielsmud.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + s94 = arg13_1 + s28 = arg15_1 + s4 = arg16_1 + s56 = arg18_1 + s84 = arg20_1 + s53 = arg21_1 + s100 = arg23_1 + s5 = arg25_1 + s10 = arg26_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg14_1, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(arg17_1, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(arg19_1, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(arg22_1, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(arg24_1, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(arg27_1, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg14_1, arg17_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream0) + del arg12_1 + del arg14_1 + del arg17_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1728 + arg1_1 = rand_strided((1, 32, 1728, 128), (7077888, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 1728 + arg3_1 = rand_strided((1, 8, 1728, 128), (1769472, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg4_1 = 1728 + arg5_1 = rand_strided((1, 8, 1728, 128), (1769472, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg6_1 = 14 + arg7_1 = 14 + arg8_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:0', dtype=torch.int32) + arg9_1 = 1728 + arg10_1 = 1728 + arg11_1 = 14 + arg12_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:0', dtype=torch.int32) + arg13_1 = 14 + arg14_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:0', dtype=torch.int32) + arg15_1 = 14 + arg16_1 = 14 + arg17_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:0', dtype=torch.int32) + arg18_1 = 14 + arg19_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:0', dtype=torch.int32) + arg20_1 = 14 + arg21_1 = 14 + arg22_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:0', dtype=torch.int32) + arg23_1 = 14 + arg24_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:0', dtype=torch.int32) + arg25_1 = 14 + arg26_1 = 14 + arg27_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/n5/cn5gxwqnt6gdbefobg2lehdtf25hdqtjwemg2dyag2vfms2qli6c.py b/progress/SpecForge/cache/compiled_kernels/n5/cn5gxwqnt6gdbefobg2lehdtf25hdqtjwemg2dyag2vfms2qli6c.py new file mode 100644 index 0000000000000000000000000000000000000000..91e64cdff7793fcdd3df83e93fb792a8935dfd10 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/n5/cn5gxwqnt6gdbefobg2lehdtf25hdqtjwemg2dyag2vfms2qli6c.py @@ -0,0 +1,707 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dt/cdt7p2o3ok2wfcldv4qsdhjija3xekjiu6ve4x6lyf3bqcrkvbz5.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/2j/c2j3mtk3thi6sn2hxiuhuigjw43spiu74mxdervpgpfrtos7u2qh.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream2) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/n5/cn5ykoaxxefbhtqa6gkemon2itryrhe2jlchrjsmbb5tt7mje2lt.py b/progress/SpecForge/cache/compiled_kernels/n5/cn5ykoaxxefbhtqa6gkemon2itryrhe2jlchrjsmbb5tt7mje2lt.py new file mode 100644 index 0000000000000000000000000000000000000000..df476fa4f5e0214c6f7aa5cf343ffbb882477825 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/n5/cn5ykoaxxefbhtqa6gkemon2itryrhe2jlchrjsmbb5tt7mje2lt.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/na/cnatuyuse5hcbtmmdzshnjxwnhrkwsyeuuwuwsnkokgtrpfnqont.py b/progress/SpecForge/cache/compiled_kernels/na/cnatuyuse5hcbtmmdzshnjxwnhrkwsyeuuwuwsnkokgtrpfnqont.py new file mode 100644 index 0000000000000000000000000000000000000000..44c92ffeecc43a9cf9f300de5fcefcc9a4a292af --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/na/cnatuyuse5hcbtmmdzshnjxwnhrkwsyeuuwuwsnkokgtrpfnqont.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/si/csiqxh5yfqziyhs4zfuiszx7ckahzxji737hjp4j5cz7dsebx4cg.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:5" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:5" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:5" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (768, 768, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 98304*idx_hq + 3145728*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/d3/cd3fyk2an5e67cbkt4zfg2vrfg3ra6vdbivectsfiwzgfqjpyudy.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_5, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_6, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_7, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_8, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_9, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_10, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_11, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 6, 1, 32, stream=stream5) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, 24576, stream=stream5) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/nb/cnbyzg37ql3x6qb2v6iq7zqfdseknzpy3i73cnjs4dnjp2kgvg3f.py b/progress/SpecForge/cache/compiled_kernels/nb/cnbyzg37ql3x6qb2v6iq7zqfdseknzpy3i73cnjs4dnjp2kgvg3f.py new file mode 100644 index 0000000000000000000000000000000000000000..c929a3fd8e42058a0868a4394b990020bc319ea5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nb/cnbyzg37ql3x6qb2v6iq7zqfdseknzpy3i73cnjs4dnjp2kgvg3f.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3735552, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 933888, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 933888, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 912 + ZKV = 1 + KV_LEN = 912 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 116736*idx_hq + 3735552*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ni/cnik5jer2sj77e3dwe5apycrphkqz4kvi3x3p5kyfm4zvmvyf7ek.py b/progress/SpecForge/cache/compiled_kernels/ni/cnik5jer2sj77e3dwe5apycrphkqz4kvi3x3p5kyfm4zvmvyf7ek.py new file mode 100644 index 0000000000000000000000000000000000000000..f58cb92a6ead8ec9b40ce249d4f71dd507f13316 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ni/cnik5jer2sj77e3dwe5apycrphkqz4kvi3x3p5kyfm4zvmvyf7ek.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/ni/eea0457df2a42410d347584a5313b132ddc7294184a64ed7a288299f4c1b907e.best_config b/progress/SpecForge/cache/compiled_kernels/ni/eea0457df2a42410d347584a5313b132ddc7294184a64ed7a288299f4c1b907e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f8775bf9d9c7f9e289b15bffae3aad4d22b82bea --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ni/eea0457df2a42410d347584a5313b132ddc7294184a64ed7a288299f4c1b907e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 84, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py b/progress/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf6f33d25fe50d395ae569a7f3a84752bbaaf83 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nj/cnjtse3xftpnmqvwojj6g7ajl3r3hvxbz3sgyaaznnrxcs7gzj2e.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 124928*idx_hq + 3997696*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/nl/47cabc7a2abcc0fb0c12f26475c60b8bf00047ffa1a14819374f1f116ad3eeb6.best_config b/progress/SpecForge/cache/compiled_kernels/nl/47cabc7a2abcc0fb0c12f26475c60b8bf00047ffa1a14819374f1f116ad3eeb6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c56b3ff6df726fa6b67725165b8989b16d820629 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nl/47cabc7a2abcc0fb0c12f26475c60b8bf00047ffa1a14819374f1f116ad3eeb6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/nl/cnlicf3gxmwqxgcrephkgh25nolfnuz7cungpmokpzy6oauhzf2v.py b/progress/SpecForge/cache/compiled_kernels/nl/cnlicf3gxmwqxgcrephkgh25nolfnuz7cungpmokpzy6oauhzf2v.py new file mode 100644 index 0000000000000000000000000000000000000000..29478151c6a293047711f3c87d242c9574d8e112 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nl/cnlicf3gxmwqxgcrephkgh25nolfnuz7cungpmokpzy6oauhzf2v.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/nn/cnntavd5p4rweamwdxomuln7et4lvnjdxfhlvm7osztfsv3r6ckw.py b/progress/SpecForge/cache/compiled_kernels/nn/cnntavd5p4rweamwdxomuln7et4lvnjdxfhlvm7osztfsv3r6ckw.py new file mode 100644 index 0000000000000000000000000000000000000000..01293694f4e0380ac53da93f2b107f3faed29672 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nn/cnntavd5p4rweamwdxomuln7et4lvnjdxfhlvm7osztfsv3r6ckw.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/no/cnoqzqsjmwsbqehq6fim3o2grhibpyaizrdu6h7ihrfnroahemhs.py b/progress/SpecForge/cache/compiled_kernels/no/cnoqzqsjmwsbqehq6fim3o2grhibpyaizrdu6h7ihrfnroahemhs.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c914cff0ce123e018bf0b5351ccf026eabfc7a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/no/cnoqzqsjmwsbqehq6fim3o2grhibpyaizrdu6h7ihrfnroahemhs.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/fs/cfs2of6ygpacl7vcs74vy37xwnw4avrqcbqon6ht56dira6khkkl.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:2" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:2" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zi/czisxx2rasv3jxm6nvp2yhkwlbob3ysmyfupuctvxqfwp52ekm3f.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_15, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_16, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_17, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_18, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_19, (1, 1, 11, 11), (121, 121, 11, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream2) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1312 + primals_2 = rand_strided((1, 32, 1312, 128), (5373952, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = 1312 + primals_4 = rand_strided((1, 8, 1312, 128), (1343488, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_5 = 1312 + primals_6 = rand_strided((1, 8, 1312, 128), (1343488, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = 11 + primals_8 = 11 + primals_9 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_10 = 1312 + primals_11 = 1312 + primals_12 = 11 + primals_13 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/no/cnox5iwofyl7r5fpqptl32ehphom56wvhety7wr52h6uwdknqaro.py b/progress/SpecForge/cache/compiled_kernels/no/cnox5iwofyl7r5fpqptl32ehphom56wvhety7wr52h6uwdknqaro.py new file mode 100644 index 0000000000000000000000000000000000000000..855bf6d53cb7a7376beb10390faa3aceb7ac42ff --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/no/cnox5iwofyl7r5fpqptl32ehphom56wvhety7wr52h6uwdknqaro.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sw/cswwaepmaupqvarnmz6ncii4cmhronmafzughe6w7yjkeroyh33w.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (768, 768, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 98304*idx_hq + 3145728*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/am/camey3wf35nw3zclfvahr3a2qldzaw7bu3nbioja6kav4yh62xeg.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_5, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_6, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_7, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_8, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_9, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_10, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_11, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 6, 1, 32, stream=stream1) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, 24576, stream=stream1) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/nr/cnrn743htvwkbpa62ur5yvwkjuqstpxlpvja3asvt7izdjyk7jet.py b/progress/SpecForge/cache/compiled_kernels/nr/cnrn743htvwkbpa62ur5yvwkjuqstpxlpvja3asvt7izdjyk7jet.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d18e8a5b9e03bac35a96cfebc95a2f27821a4d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nr/cnrn743htvwkbpa62ur5yvwkjuqstpxlpvja3asvt7izdjyk7jet.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sc/csc67q546j7muhvoccled2oa4u3qywv2ugybi3n244etkkzozhm2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1104, 128][4521984, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1104, 1104, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4521984, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1130496, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1130496, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1104 + ZKV = 1 + KV_LEN = 1104 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 141312*idx_hq + 4521984*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zr/czr6mneyujqwtcnujzpzrifkj4z36xtvmjyymecmrkgwf7765fr3.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 282624}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 35328 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1104, 128), (4521984, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_5, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_6, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_7, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_8, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_9, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_10, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_11, (1, 1, 9, 9), (81, 81, 9, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, 1104), (35328, 1104, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1104), (35328, 1104, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1104, 128), (4521984, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 9, 1, 32, stream=stream3) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_poi_fused_mul_1.run(buf0, buf5, 35328, stream=stream3) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/nt/cntrepsg3bvsfcqbuics5zxl5te76fcgzahhb7rkhknbzmtgn3h6.py b/progress/SpecForge/cache/compiled_kernels/nt/cntrepsg3bvsfcqbuics5zxl5te76fcgzahhb7rkhknbzmtgn3h6.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a65e9dcdf7b06a1b91672ebbc3fe6ad2108b20 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nt/cntrepsg3bvsfcqbuics5zxl5te76fcgzahhb7rkhknbzmtgn3h6.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/nz/cnz3o2q2vb2ti2gyyevhvaopio72peoav2vt76s5rmoxvuy4zs6c.py b/progress/SpecForge/cache/compiled_kernels/nz/cnz3o2q2vb2ti2gyyevhvaopio72peoav2vt76s5rmoxvuy4zs6c.py new file mode 100644 index 0000000000000000000000000000000000000000..385867c310ed2337ad32a384e48d9b6d04c05455 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/nz/cnz3o2q2vb2ti2gyyevhvaopio72peoav2vt76s5rmoxvuy4zs6c.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/o5/co5tyobl4yghx6xegllqe7pgubmmdmajgrzhb45ymyxgtk6vy5h5.py b/progress/SpecForge/cache/compiled_kernels/o5/co5tyobl4yghx6xegllqe7pgubmmdmajgrzhb45ymyxgtk6vy5h5.py new file mode 100644 index 0000000000000000000000000000000000000000..64a70b237dfcf4579ff811803cf208c0a6c4df12 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/o5/co5tyobl4yghx6xegllqe7pgubmmdmajgrzhb45ymyxgtk6vy5h5.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/o5/co5ymfefdgahcwkjo4bf6n435ycdi63hj2qw5lci4mp5y23gvenx.py b/progress/SpecForge/cache/compiled_kernels/o5/co5ymfefdgahcwkjo4bf6n435ycdi63hj2qw5lci4mp5y23gvenx.py new file mode 100644 index 0000000000000000000000000000000000000000..69cca7a2d5b30dcbd8462e2003cb5ef22a6a4e66 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/o5/co5ymfefdgahcwkjo4bf6n435ycdi63hj2qw5lci4mp5y23gvenx.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/od/40309b9cf90ae0c8127ebbeb2aa91d0a17d0f2643fc27f76b9c3c4c82192c151.best_config b/progress/SpecForge/cache/compiled_kernels/od/40309b9cf90ae0c8127ebbeb2aa91d0a17d0f2643fc27f76b9c3c4c82192c151.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c56b3ff6df726fa6b67725165b8989b16d820629 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/od/40309b9cf90ae0c8127ebbeb2aa91d0a17d0f2643fc27f76b9c3c4c82192c151.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/od/codhw4octxotjgzwvuzdm3pslxxwo6zefy26otese4tvwxgfy45f.py b/progress/SpecForge/cache/compiled_kernels/od/codhw4octxotjgzwvuzdm3pslxxwo6zefy26otese4tvwxgfy45f.py new file mode 100644 index 0000000000000000000000000000000000000000..b645c94477877ae9bbe966eb718534497dc9f1d8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/od/codhw4octxotjgzwvuzdm3pslxxwo6zefy26otese4tvwxgfy45f.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/oh/cohge6z4ztoq2k5jikh42tczggffkgdsamlr54jbnxsw3sl57kc3.py b/progress/SpecForge/cache/compiled_kernels/oh/cohge6z4ztoq2k5jikh42tczggffkgdsamlr54jbnxsw3sl57kc3.py new file mode 100644 index 0000000000000000000000000000000000000000..c709228b047a6b9117e9929fe6f0bef80963c6fc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/oh/cohge6z4ztoq2k5jikh42tczggffkgdsamlr54jbnxsw3sl57kc3.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 786432, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3145728, 98304, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3145728, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 98304*off_hkv + 786432*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/oi/coi77xrrywawfst2wdbmeyzt2efi476di4gsrg2meicbywnouebm.py b/progress/SpecForge/cache/compiled_kernels/oi/coi77xrrywawfst2wdbmeyzt2efi476di4gsrg2meicbywnouebm.py new file mode 100644 index 0000000000000000000000000000000000000000..d336227f16bab3b18b5fd7c8a39bddd3d6739926 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/oi/coi77xrrywawfst2wdbmeyzt2efi476di4gsrg2meicbywnouebm.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 13 + stride_q_idx_h = 169 + stride_q_idx_n = 13 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ol/col5xr7m24v6kb66aa3rur4q4bjmalhin4fevhxsdzasef34d74q.py b/progress/SpecForge/cache/compiled_kernels/ol/col5xr7m24v6kb66aa3rur4q4bjmalhin4fevhxsdzasef34d74q.py new file mode 100644 index 0000000000000000000000000000000000000000..a06d291b49149bc28b087d464dcfa6491369356f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ol/col5xr7m24v6kb66aa3rur4q4bjmalhin4fevhxsdzasef34d74q.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/os/coscuyietmpq5e7tohbdukj24ad47xgkkqif3bm5qyvq7drfabzi.py b/progress/SpecForge/cache/compiled_kernels/os/coscuyietmpq5e7tohbdukj24ad47xgkkqif3bm5qyvq7drfabzi.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ab0fe64fcec3b5d8008ae425e974b2b8e94c44 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/os/coscuyietmpq5e7tohbdukj24ad47xgkkqif3bm5qyvq7drfabzi.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 196608, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 49152, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 49152, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 49152, 1536, 48, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 49152, 1536, 48, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = 48 + KV_LEN = 48 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t + 6291456*idx_z + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/os/cosyzqwfhjripwes7gofi4cc5yi235vjdqacukvxrwfr7zj2vsg7.py b/progress/SpecForge/cache/compiled_kernels/os/cosyzqwfhjripwes7gofi4cc5yi235vjdqacukvxrwfr7zj2vsg7.py new file mode 100644 index 0000000000000000000000000000000000000000..6dbf108b69b300a584d3d89f9eb557ed5348d863 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/os/cosyzqwfhjripwes7gofi4cc5yi235vjdqacukvxrwfr7zj2vsg7.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/ot/cotrsbintrwogfyqqgnphf2yeyr6yfylbjbvpw7omv42gwfsf4bc.py b/progress/SpecForge/cache/compiled_kernels/ot/cotrsbintrwogfyqqgnphf2yeyr6yfylbjbvpw7omv42gwfsf4bc.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2893578e9d4d3162358927840ed619bc272d14 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ot/cotrsbintrwogfyqqgnphf2yeyr6yfylbjbvpw7omv42gwfsf4bc.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 380928}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 47616 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/ot/d144752e97a140328179a677fb16f0e3b5505ff34fca91e834ade58acf1600ff.best_config b/progress/SpecForge/cache/compiled_kernels/ot/d144752e97a140328179a677fb16f0e3b5505ff34fca91e834ade58acf1600ff.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cfea0cd07ad3400bd4f59924e059c40c2131dc7b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ot/d144752e97a140328179a677fb16f0e3b5505ff34fca91e834ade58acf1600ff.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "3POBVEMFF6HCRKLIB7YNNJRR2EJ2EKYOLLPRWY76J74FBEFLCFGA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ov/cov2xtsbfix5lto5mazwkmpgzwhukonzh73gncg4gr5xfm6dtfd7.py b/progress/SpecForge/cache/compiled_kernels/ov/cov2xtsbfix5lto5mazwkmpgzwhukonzh73gncg4gr5xfm6dtfd7.py new file mode 100644 index 0000000000000000000000000000000000000000..3859e225965269200370521771ba98ef42a2ae4f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ov/cov2xtsbfix5lto5mazwkmpgzwhukonzh73gncg4gr5xfm6dtfd7.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/oy/82551e73adfef1714f981bdc656d889767ac348a29917bcb08707930199a3b9e.best_config b/progress/SpecForge/cache/compiled_kernels/oy/82551e73adfef1714f981bdc656d889767ac348a29917bcb08707930199a3b9e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..455d0b61f7b8430200b6693df690e504b02408fb --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/oy/82551e73adfef1714f981bdc656d889767ac348a29917bcb08707930199a3b9e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 4, "R0_BLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 39, "triton_cache_hash": "OCLTNORRQU6QEOK4RR5KGKQE2JLQZNWPRX3GINXQOL62XTLAX3YA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/oy/coyqzoovg3nwccfy2s7q4idglwhoj44wvrhlg42qsg4b665rrwuo.py b/progress/SpecForge/cache/compiled_kernels/oy/coyqzoovg3nwccfy2s7q4idglwhoj44wvrhlg42qsg4b665rrwuo.py new file mode 100644 index 0000000000000000000000000000000000000000..cc8f5d1a6a24bd996a51145497d8de9fbb01ba1e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/oy/coyqzoovg3nwccfy2s7q4idglwhoj44wvrhlg42qsg4b665rrwuo.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 294912, 'r0_': 12582912}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 24576 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 768) + x1 = xindex // 768 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, None) diff --git a/progress/SpecForge/cache/compiled_kernels/p3/cp3ugynythnqlux2v4dqu5lutbwlwb4ms5jfijbkvbbzrshao7pv.py b/progress/SpecForge/cache/compiled_kernels/p3/cp3ugynythnqlux2v4dqu5lutbwlwb4ms5jfijbkvbbzrshao7pv.py new file mode 100644 index 0000000000000000000000000000000000000000..338de56afdd77838376899f0e19a65c8ace01154 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/p3/cp3ugynythnqlux2v4dqu5lutbwlwb4ms5jfijbkvbbzrshao7pv.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ea/ceajfbjjrurt6t4led3jw4lyfkf23ku2btxjrnxzwjgjo26pomeq.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 656, 128][2686976, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 656, 128][671744, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 656, 128][671744, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 656][20992, 656, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 656][20992, 656, 1]cuda:7" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:7" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:7" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (656, 656, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2686976, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 671744, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 671744, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 656 + ZKV = 1 + KV_LEN = 656 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 83968*idx_hq + 2686976*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wc/cwcq2mf7cnltkrq7krmalipwcxj2b46f7bdyonvnqsqw5uhfomts.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 656][20992, 656, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 167936}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 20992 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 656, 128), (2686976, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 656, 128), (671744, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 656, 128), (671744, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg4_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg5_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg6_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg7_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg8_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg9_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg10_1, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, 656), (20992, 656, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 656), (20992, 656, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 656, 128), (2686976, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 6, 1, 32, stream=stream7) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, 20992, stream=stream7) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 656, 128), (2686976, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 656, 128), (671744, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 656, 128), (671744, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/pn/cpn5szp5vkygpknzuabawp6qrjqxkaoy4kmmo2g5d7gltrujss2t.py b/progress/SpecForge/cache/compiled_kernels/pn/cpn5szp5vkygpknzuabawp6qrjqxkaoy4kmmo2g5d7gltrujss2t.py new file mode 100644 index 0000000000000000000000000000000000000000..b452268ed0bc4f5d4afceae7587d90c6b32a3930 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/pn/cpn5szp5vkygpknzuabawp6qrjqxkaoy4kmmo2g5d7gltrujss2t.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ps/b738474c858c2490c89271cd2af5c91fec8bf2f3d04e03aae8a57f2d0c95795e.best_config b/progress/SpecForge/cache/compiled_kernels/ps/b738474c858c2490c89271cd2af5c91fec8bf2f3d04e03aae8a57f2d0c95795e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2df7a8a1ca0ef6a0f733580abfb725612f65a921 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ps/b738474c858c2490c89271cd2af5c91fec8bf2f3d04e03aae8a57f2d0c95795e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 4, "R0_BLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 40, "triton_cache_hash": "AIFEPBEOFTWCYZ4HKMAVVJSWB23JBZPVI2N7V4RDBJPWEXDWUG7A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py b/progress/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py new file mode 100644 index 0000000000000000000000000000000000000000..bc91d4d4714cb5b7f5c0cf4436ef04eae1b97818 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 374784, 'r0_': 15990784}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 31232 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 976) + x1 = xindex // 976 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/px/cpxzws5bmrulopqu5y7bt3xc26jz2bgfnb45ptng2di4kjf2aezq.py b/progress/SpecForge/cache/compiled_kernels/px/cpxzws5bmrulopqu5y7bt3xc26jz2bgfnb45ptng2di4kjf2aezq.py new file mode 100644 index 0000000000000000000000000000000000000000..4f73d076974ffaf4a2d030ed7512b0d9d9b993da --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/px/cpxzws5bmrulopqu5y7bt3xc26jz2bgfnb45ptng2di4kjf2aezq.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2686976, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 671744, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 671744, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 656 + ZKV = 1 + KV_LEN = 656 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 83968*idx_hq + 2686976*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/q5/cq5goucsdqab6lsesb2pwahdi6rt55rvaw3echwbytpaeygwupgh.py b/progress/SpecForge/cache/compiled_kernels/q5/cq5goucsdqab6lsesb2pwahdi6rt55rvaw3echwbytpaeygwupgh.py new file mode 100644 index 0000000000000000000000000000000000000000..74a8467b8c4182485abc733fddf8df93994cb990 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/q5/cq5goucsdqab6lsesb2pwahdi6rt55rvaw3echwbytpaeygwupgh.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/q7/cq7b5ms44hcmrr7feriutowpmikxk2ctxzdbbsr7j7kyqd2c4b77.py b/progress/SpecForge/cache/compiled_kernels/q7/cq7b5ms44hcmrr7feriutowpmikxk2ctxzdbbsr7j7kyqd2c4b77.py new file mode 100644 index 0000000000000000000000000000000000000000..9676e7ee7be86dd39f35da0149390e78570e1986 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/q7/cq7b5ms44hcmrr7feriutowpmikxk2ctxzdbbsr7j7kyqd2c4b77.py @@ -0,0 +1,707 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5b/c5blvz5sxoj2veuexokuub2zm2pg4l2nqbbny4rr2jhsiiyw6njy.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wh/cwhdpqc7err4iroi3t4bxgkbxdgw2bnqeqr4zt5upvah2ilnbh2e.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream7) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream7) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/qi/cqiq4qr54gju3rsitwenjmqmqmbcgp53lgxmlnxx4r5zf53dvozc.py b/progress/SpecForge/cache/compiled_kernels/qi/cqiq4qr54gju3rsitwenjmqmqmbcgp53lgxmlnxx4r5zf53dvozc.py new file mode 100644 index 0000000000000000000000000000000000000000..c97035d6fbc930a3df1c8fd22229ffcff8f12c8d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/qi/cqiq4qr54gju3rsitwenjmqmqmbcgp53lgxmlnxx4r5zf53dvozc.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iy/ciycjvyaco6b3mtyliulibjjjnk2hxipagfayteqs2axxcgg4f2h.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:1" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:1" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:1" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:1" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:1" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:1" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream1) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 704 + primals_11 = 704 + primals_7 = 6 + primals_8 = 6 + primals_12 = 6 + primals_14 = 6 + primals_16 = 6 + primals_17 = 6 + primals_19 = 6 + primals_22 = 6 + primals_21 = 6 + primals_24 = 6 + primals_27 = 6 + primals_26 = 6 + primals_2 = rand_strided((1, 32, 704, 128), (2883584, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 704, 128), (720896, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 704, 128), (720896, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((1, 32, 704, 128), (2883584, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 704), (22528, 704, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 704, 128), (2883584, 90112, 128, 1), device='cuda:1', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 704), (22528, 704, 1), device='cuda:1', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py b/progress/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py new file mode 100644 index 0000000000000000000000000000000000000000..919ed21e1ba855b82b5b099d17dfd85bd02a4017 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 79872*idx_hq + 2555904*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/qo/cqo4mcqb5vtzpwrhroi4ndfou3vehwggsyc4o4mdxrxslpnjnayx.py b/progress/SpecForge/cache/compiled_kernels/qo/cqo4mcqb5vtzpwrhroi4ndfou3vehwggsyc4o4mdxrxslpnjnayx.py new file mode 100644 index 0000000000000000000000000000000000000000..9e55528035ae9d6b47d453348661859421646949 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/qo/cqo4mcqb5vtzpwrhroi4ndfou3vehwggsyc4o4mdxrxslpnjnayx.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4b/c4bs4faiglroap6s5exqskt27kcrendxbin52u3h5veyhjham26s.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gd/cgdtsnyqm2fsveto5s6yskc7delp34hsyteokw2zkmblhdxv622h.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:3" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:3" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:3" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:3" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:3" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:3" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream3) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 928 + primals_11 = 928 + primals_7 = 8 + primals_8 = 8 + primals_12 = 8 + primals_14 = 8 + primals_16 = 8 + primals_17 = 8 + primals_19 = 8 + primals_22 = 8 + primals_21 = 8 + primals_24 = 8 + primals_27 = 8 + primals_26 = 8 + primals_2 = rand_strided((1, 32, 928, 128), (3801088, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 928, 128), (950272, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:3', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((1, 32, 928, 128), (3801088, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 928), (29696, 928, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 928, 128), (3801088, 118784, 128, 1), device='cuda:3', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 928), (29696, 928, 1), device='cuda:3', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/qp/cqpspcatxe6qlrdn4i2adjktoawycld6ckjcptfj4kefnavpletp.py b/progress/SpecForge/cache/compiled_kernels/qp/cqpspcatxe6qlrdn4i2adjktoawycld6ckjcptfj4kefnavpletp.py new file mode 100644 index 0000000000000000000000000000000000000000..955b81d23cca5bb3f24c9e9c7901a1a025651f7c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/qp/cqpspcatxe6qlrdn4i2adjktoawycld6ckjcptfj4kefnavpletp.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5k/c5kiw545j4jh6gzwff5a5mpcfuhjlcvxbrhzjxxfwcjggt7l6rju.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ov/cov2xtsbfix5lto5mazwkmpgzwhukonzh73gncg4gr5xfm6dtfd7.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:5" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:5" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:5" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:5" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:5" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:5" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream5) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 496 + primals_11 = 496 + primals_7 = 4 + primals_8 = 4 + primals_12 = 4 + primals_14 = 4 + primals_16 = 4 + primals_17 = 4 + primals_19 = 4 + primals_22 = 4 + primals_21 = 4 + primals_24 = 4 + primals_27 = 4 + primals_26 = 4 + primals_2 = rand_strided((1, 32, 496, 128), (2031616, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 496, 128), (507904, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 496, 128), (507904, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((1, 32, 496, 128), (2031616, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 496), (15872, 496, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 496, 128), (2031616, 63488, 128, 1), device='cuda:5', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 496), (15872, 496, 1), device='cuda:5', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/qt/38c0339fd9d8f3409013bdce3fdd3b45bb41528f46b543b7d1d9db29a478dd5f.best_config b/progress/SpecForge/cache/compiled_kernels/qt/38c0339fd9d8f3409013bdce3fdd3b45bb41528f46b543b7d1d9db29a478dd5f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3514cd4b7efbcfec5a9027a69143f1b0c3ed176a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/qt/38c0339fd9d8f3409013bdce3fdd3b45bb41528f46b543b7d1d9db29a478dd5f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/qt/cqtdareqz6tivounhz5hfeqyiotaxvjxc2v3hngtznpiz5nqvksa.py b/progress/SpecForge/cache/compiled_kernels/qt/cqtdareqz6tivounhz5hfeqyiotaxvjxc2v3hngtznpiz5nqvksa.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2b11376c510d529e4867c9414df4cb82a956ec --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/qt/cqtdareqz6tivounhz5hfeqyiotaxvjxc2v3hngtznpiz5nqvksa.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/r5/cr5gimrtqf5gqmgkgirz6nebp4ijwtbzhmqlsdgjqda33dr43qre.py b/progress/SpecForge/cache/compiled_kernels/r5/cr5gimrtqf5gqmgkgirz6nebp4ijwtbzhmqlsdgjqda33dr43qre.py new file mode 100644 index 0000000000000000000000000000000000000000..a41a461906d31a8fc90234abafd48576a9d51c42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/r5/cr5gimrtqf5gqmgkgirz6nebp4ijwtbzhmqlsdgjqda33dr43qre.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 36864, 'r0_': 0}} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1536 + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 1536*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 1536*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x0), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/r5/ead326e7759c96061cc1c41fb4e42ef3b5dd189f6691a4507dc929de78e15b5c.best_config b/progress/SpecForge/cache/compiled_kernels/r5/ead326e7759c96061cc1c41fb4e42ef3b5dd189f6691a4507dc929de78e15b5c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..7b6d9a1b94545447e9df6c1447ac697425a25868 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/r5/ead326e7759c96061cc1c41fb4e42ef3b5dd189f6691a4507dc929de78e15b5c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 32, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 76, "triton_cache_hash": "ZGUZG4PPPADQSTHIUCWKTER32EZVBCECRNJ4C65FS7KX5AOLZL5Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/r7/cr7yqbjyejqfa2ld73pqqr2tp7g6ybrhs4u34xt5s4cusntoze4a.py b/progress/SpecForge/cache/compiled_kernels/r7/cr7yqbjyejqfa2ld73pqqr2tp7g6ybrhs4u34xt5s4cusntoze4a.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfbb7d378030ad112403946cfe42b65eb6c7f60 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/r7/cr7yqbjyejqfa2ld73pqqr2tp7g6ybrhs4u34xt5s4cusntoze4a.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 294912, 'r0_': 12582912}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 24576 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 768) + x1 = xindex // 768 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, None) diff --git a/progress/SpecForge/cache/compiled_kernels/r7/fb64a749c45cc85864be8fdd78d3e2ddb1af4812a2f3c8ac1b1cd4cfd4ee8558.best_config b/progress/SpecForge/cache/compiled_kernels/r7/fb64a749c45cc85864be8fdd78d3e2ddb1af4812a2f3c8ac1b1cd4cfd4ee8558.best_config new file mode 100644 index 0000000000000000000000000000000000000000..91f7d3b60741c38f7cd773461ac0c17d0b164739 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/r7/fb64a749c45cc85864be8fdd78d3e2ddb1af4812a2f3c8ac1b1cd4cfd4ee8558.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 39, "triton_cache_hash": "P2OQFKFK2VYOLKX3ZFEIJCKRP37BJTFKDRI2TSDR5SQJLAYZETYQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/rg/59ac99122cc3b5c2042e5d215ef45fb00654e730d721a1488c0f62a7433b7445.best_config b/progress/SpecForge/cache/compiled_kernels/rg/59ac99122cc3b5c2042e5d215ef45fb00654e730d721a1488c0f62a7433b7445.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d65196b98f27b9ceb4228a80b253cf313fff4a5a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rg/59ac99122cc3b5c2042e5d215ef45fb00654e730d721a1488c0f62a7433b7445.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "JIKOMHE3JAKFHZL5B4DIF7M3OFXH2UR3ALMAT34JQOPMCUY6YVJA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/rg/crgxcf2kpnvse2bk65ukwreddvsjjcwhijeklqq7egtduqxsqwue.py b/progress/SpecForge/cache/compiled_kernels/rg/crgxcf2kpnvse2bk65ukwreddvsjjcwhijeklqq7egtduqxsqwue.py new file mode 100644 index 0000000000000000000000000000000000000000..a119474b61e31a49994e70317ef0faeabaf40d38 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rg/crgxcf2kpnvse2bk65ukwreddvsjjcwhijeklqq7egtduqxsqwue.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 159744}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 19968 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/rh/765eaa0aa11ed7bec55ff68482aaa424aaca68cb3a58bca6949cf199eaada52f.best_config b/progress/SpecForge/cache/compiled_kernels/rh/765eaa0aa11ed7bec55ff68482aaa424aaca68cb3a58bca6949cf199eaada52f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rh/765eaa0aa11ed7bec55ff68482aaa424aaca68cb3a58bca6949cf199eaada52f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/rh/bf6e21a1e348878aa4f6ee1826b2f30ec7cda90c377eed1d83fdc8fbef404d3b.best_config b/progress/SpecForge/cache/compiled_kernels/rh/bf6e21a1e348878aa4f6ee1826b2f30ec7cda90c377eed1d83fdc8fbef404d3b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..56a643a6c205f101d22e1d65692ff2fcac0c595e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rh/bf6e21a1e348878aa4f6ee1826b2f30ec7cda90c377eed1d83fdc8fbef404d3b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 8, "num_stages": 1, "configs_hash": "22b8c9e89632e6687ce26aaad980a76bbf5ee683fff317f3a6d7989c7528ff63", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "WJHIHLPATQZBKSQZSWJ5BD3ABYGFUF3YD6VF633RGCNWMMKVXCCA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/rh/crhihk6v6nqpqusr4vp7qqztet42ugdhn524fpi6y32vdhe2bair.py b/progress/SpecForge/cache/compiled_kernels/rh/crhihk6v6nqpqusr4vp7qqztet42ugdhn524fpi6y32vdhe2bair.py new file mode 100644 index 0000000000000000000000000000000000000000..b595c65deec1b918403644452699f9772308e9e4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rh/crhihk6v6nqpqusr4vp7qqztet42ugdhn524fpi6y32vdhe2bair.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/rh/crhmk736vsio3af666cgqjvjsg6hqfa35qx276udet5tizr6uf35.py b/progress/SpecForge/cache/compiled_kernels/rh/crhmk736vsio3af666cgqjvjsg6hqfa35qx276udet5tizr6uf35.py new file mode 100644 index 0000000000000000000000000000000000000000..bba58da9a7c23a89c96a9483353c5ce5568965a6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rh/crhmk736vsio3af666cgqjvjsg6hqfa35qx276udet5tizr6uf35.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/rh/crhomk4dbfrkitfzvo6fho5x4nowgfrwmpv4nnpqbiigzvcf6dm5.py b/progress/SpecForge/cache/compiled_kernels/rh/crhomk4dbfrkitfzvo6fho5x4nowgfrwmpv4nnpqbiigzvcf6dm5.py new file mode 100644 index 0000000000000000000000000000000000000000..84b040bdcd56ed90940646d287805b728dd82725 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rh/crhomk4dbfrkitfzvo6fho5x4nowgfrwmpv4nnpqbiigzvcf6dm5.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/rj/crjnxfiu4q7l2yt2toklk4exbw7ob3o5uej7oia6rdhsfvegvdo6.py b/progress/SpecForge/cache/compiled_kernels/rj/crjnxfiu4q7l2yt2toklk4exbw7ob3o5uej7oia6rdhsfvegvdo6.py new file mode 100644 index 0000000000000000000000000000000000000000..757997377e85294e63059a78b037efcf63e8ca18 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rj/crjnxfiu4q7l2yt2toklk4exbw7ob3o5uej7oia6rdhsfvegvdo6.py @@ -0,0 +1,707 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l2/cl2dotglxdqten7ibak6ugqut6sv6idwda7vpobhf4qjozvyezam.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ro/crocdjfs7cfxfmy2gaelowllr5hrdwvl3jmen62vrsdfh5kpbtwa.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream1) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream1) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/rk/crk3vxt3qkwpuz7ytaltbszcyjqibnls6m2evah6ox3kh7acqtmx.py b/progress/SpecForge/cache/compiled_kernels/rk/crk3vxt3qkwpuz7ytaltbszcyjqibnls6m2evah6ox3kh7acqtmx.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9dca6e3ed7fdd653a2bd876c790c6b3e979117 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rk/crk3vxt3qkwpuz7ytaltbszcyjqibnls6m2evah6ox3kh7acqtmx.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gb/cgb3qo2uhbvd22jllo3pcd4qrm2qklyyeuprletllykljldff6hp.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1712, 1712, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 219136*idx_hq + 7012352*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/mk/cmks4hkjvgdwntls2t5soqtos7ncc2vfg7e7gzbvmtvbumfvqdyc.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 438272}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 54784 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1712, 128), (7012352, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1712, 128), (1753088, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1712, 128), (1753088, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(primals_5, (1, 1, 14), (14, 14, 1)) + assert_size_stride(primals_6, (1, 1, 14), (14, 14, 1)) + assert_size_stride(primals_7, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(primals_8, (1, 1, 14), (14, 14, 1)) + assert_size_stride(primals_9, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(primals_10, (1, 1, 14), (14, 14, 1)) + assert_size_stride(primals_11, (1, 1, 14, 14), (196, 196, 14, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1712, 128), (7012352, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 14, 1, 32, stream=stream2) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, 54784, stream=stream2) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/rk/crko3bgr4fgl6gq6tpipoxlq2phtijfzeyn5k4s34g4y5idh4jrt.py b/progress/SpecForge/cache/compiled_kernels/rk/crko3bgr4fgl6gq6tpipoxlq2phtijfzeyn5k4s34g4y5idh4jrt.py new file mode 100644 index 0000000000000000000000000000000000000000..55d6869446e147356813d2f2ea358219381b566d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rk/crko3bgr4fgl6gq6tpipoxlq2phtijfzeyn5k4s34g4y5idh4jrt.py @@ -0,0 +1,1019 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream6) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 128 + primals_9 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 128, 128), (524288, 16384, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 128), (4096, 128, 1), device='cuda:6', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/rm/crmcxgaaazqrlmedkduo2v2y2nubqbvhg4ohmkpuxanwapbiwfvp.py b/progress/SpecForge/cache/compiled_kernels/rm/crmcxgaaazqrlmedkduo2v2y2nubqbvhg4ohmkpuxanwapbiwfvp.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd1649d8a7788aaf7e8aa1c83f8028371fa671d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rm/crmcxgaaazqrlmedkduo2v2y2nubqbvhg4ohmkpuxanwapbiwfvp.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ro/1ae12e42fd022adc2867c37d98888c972c6e708c9b62d244a0dcad4e8f5e2543.best_config b/progress/SpecForge/cache/compiled_kernels/ro/1ae12e42fd022adc2867c37d98888c972c6e708c9b62d244a0dcad4e8f5e2543.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ro/1ae12e42fd022adc2867c37d98888c972c6e708c9b62d244a0dcad4e8f5e2543.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ro/crocdjfs7cfxfmy2gaelowllr5hrdwvl3jmen62vrsdfh5kpbtwa.py b/progress/SpecForge/cache/compiled_kernels/ro/crocdjfs7cfxfmy2gaelowllr5hrdwvl3jmen62vrsdfh5kpbtwa.py new file mode 100644 index 0000000000000000000000000000000000000000..955dac341f63356d4493d2252aebdaa22e289e9f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ro/crocdjfs7cfxfmy2gaelowllr5hrdwvl3jmen62vrsdfh5kpbtwa.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/rv/crvrpwzjooqpmkmke5agssyajlkpcwjjy7wqqe3z3bqap2dgpluz.py b/progress/SpecForge/cache/compiled_kernels/rv/crvrpwzjooqpmkmke5agssyajlkpcwjjy7wqqe3z3bqap2dgpluz.py new file mode 100644 index 0000000000000000000000000000000000000000..f61871def2a10a882143dc2e88b7c6be853d0b7f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/rv/crvrpwzjooqpmkmke5agssyajlkpcwjjy7wqqe3z3bqap2dgpluz.py @@ -0,0 +1,1018 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/r7/cr7yqbjyejqfa2ld73pqqr2tp7g6ybrhs4u34xt5s4cusntoze4a.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 768, 128][3145728, 98304, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (768, 768, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 294912, 'r0_': 12582912}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 24576 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 768) + x1 = xindex // 768 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xr/cxr7lbwuek4pgt6vjfn4e5p45m7ggupaspkah63tfuruz7jsob2z.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 768, 128][3145728, 98304, 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_8 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:1" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:1" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (768, 768, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 786432, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3145728, 98304, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3145728, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 6 + stride_q_idx_h = 36 + stride_q_idx_n = 6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 98304*off_hkv + 786432*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 768 + KV_LEN = 768 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_5, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_6, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_7, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_8, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_9, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_10, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_11, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(getitem, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 768), (24576, 768, 1)) + assert_size_stride(tangents_1, (1, 32, 768, 128), (3145728, 98304, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 768), (24576, 768, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 24576, 128, stream=stream1) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 768, 128), (786432, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 768, 128), (786432, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 30, 1, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:1', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 768), (24576, 768, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 768, 128), (3145728, 98304, 128, 1), device='cuda:1', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 768), (24576, 768, 1), device='cuda:1', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/s2/cs2ckboafjbrrbo5wr4xu4knlgkjyyoeypuyz3hhzzdabjrhvnjw.py b/progress/SpecForge/cache/compiled_kernels/s2/cs2ckboafjbrrbo5wr4xu4knlgkjyyoeypuyz3hhzzdabjrhvnjw.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ead64cba3162dc0dfc58da30efc0b4a597a40d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/s2/cs2ckboafjbrrbo5wr4xu4knlgkjyyoeypuyz3hhzzdabjrhvnjw.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/s3/cs3m2wyjgyemkpwwhyke4e37aca3vhvlm62wgavg7wpnf66i7mlw.py b/progress/SpecForge/cache/compiled_kernels/s3/cs3m2wyjgyemkpwwhyke4e37aca3vhvlm62wgavg7wpnf66i7mlw.py new file mode 100644 index 0000000000000000000000000000000000000000..72f8941cb58a31c36ac5dc6ab3336b2463f51443 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/s3/cs3m2wyjgyemkpwwhyke4e37aca3vhvlm62wgavg7wpnf66i7mlw.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ld/cldwkxdv7izigkhcjor2j2snktzj2wsktvhftnyigkxfdpjoqggz.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1360, 128][5570560, 128, 4096, 1]cuda:5" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1360, 128][1392640, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1360, 128][1392640, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1360][43520, 1360, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1360][43520, 1360, 1]cuda:5" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:5" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:5" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:5" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1360, 1360, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 5570560, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1392640, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1392640, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1360 + ZKV = 1 + KV_LEN = 1360 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 11 + stride_kv_idx_h = 121 + stride_kv_idx_m = 11 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 174080*idx_hq + 5570560*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3b/c3bqw7dk7k6dcdrp3ycrthotye7y6zb26752jl4lwmfgaybpvr6y.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1360][43520, 1360, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 348160}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 43520 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1360, 128), (5570560, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1360, 128), (1392640, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1360, 128), (1392640, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 11), (11, 11, 1)) + assert_size_stride(arg4_1, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(arg5_1, (1, 1, 11), (11, 11, 1)) + assert_size_stride(arg6_1, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(arg7_1, (1, 1, 11), (11, 11, 1)) + assert_size_stride(arg8_1, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(arg9_1, (1, 1, 11), (11, 11, 1)) + assert_size_stride(arg10_1, (1, 1, 11, 11), (121, 121, 11, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, 1360), (43520, 1360, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1360), (43520, 1360, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1360, 128), (5570560, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 11, 1, 32, stream=stream5) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, 43520, stream=stream5) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1360, 128), (5570560, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1360, 128), (1392640, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1360, 128), (1392640, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:5', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:5', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:5', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:5', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:5', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:5', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:5', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/s6/cs65dxnpgdwrnp5zkrfsjw7nr72wfjiemn6jstvxjk3k2pzejimr.py b/progress/SpecForge/cache/compiled_kernels/s6/cs65dxnpgdwrnp5zkrfsjw7nr72wfjiemn6jstvxjk3k2pzejimr.py new file mode 100644 index 0000000000000000000000000000000000000000..12f50733f02f239a62b62f3711874283dabfa81b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/s6/cs65dxnpgdwrnp5zkrfsjw7nr72wfjiemn6jstvxjk3k2pzejimr.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/kf/ckfyjfflfghyx5yrn2qsqwnamnhitwuy4eihaljnar5gcxcp63oq.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 912, 128][3735552, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 912, 128][933888, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 912, 128][933888, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 912][29184, 912, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 912][29184, 912, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:2" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:2" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (912, 912, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3735552, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 933888, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 933888, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 912 + ZKV = 1 + KV_LEN = 912 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 116736*idx_hq + 3735552*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l6/cl6xhyxtqtnczpdlmb3w5oimiq4gvtasqjswbgz4yakhnh4qkps7.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 912][29184, 912, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 233472}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 29184 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 912, 128), (3735552, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 912, 128), (933888, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 912, 128), (933888, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg4_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg5_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg6_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg7_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg8_1, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(arg9_1, (1, 1, 8), (8, 8, 1)) + assert_size_stride(arg10_1, (1, 1, 8, 8), (64, 64, 8, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, 912), (29184, 912, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 912), (29184, 912, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 912, 128), (3735552, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 8, 1, 32, stream=stream2) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, 29184, stream=stream2) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 912, 128), (3735552, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 912, 128), (933888, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 912, 128), (933888, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:2', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:2', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:2', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:2', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:2', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:2', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:2', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/s6/cs6lqarc6kmhcimzzs63vsmbfwzgweck3gfyvpxcuvp7ql4t5hya.py b/progress/SpecForge/cache/compiled_kernels/s6/cs6lqarc6kmhcimzzs63vsmbfwzgweck3gfyvpxcuvp7ql4t5hya.py new file mode 100644 index 0000000000000000000000000000000000000000..99d6f6678ab1e61c37bc567c8eeaae12fa4a6a48 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/s6/cs6lqarc6kmhcimzzs63vsmbfwzgweck3gfyvpxcuvp7ql4t5hya.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/s7/cs7r6qssasm4h7mbus75odr37ao4k4siad5njw4m4q75lkf2ogzz.py b/progress/SpecForge/cache/compiled_kernels/s7/cs7r6qssasm4h7mbus75odr37ao4k4siad5njw4m4q75lkf2ogzz.py new file mode 100644 index 0000000000000000000000000000000000000000..5eefbdf97e8973376eee4317fc9650e9eb9f5a1f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/s7/cs7r6qssasm4h7mbus75odr37ao4k4siad5njw4m4q75lkf2ogzz.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4d/c4d7fh2egdfps7aogbncwlp3ihfwtff243bbobq7vrxj2m2grl64.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wm/cwmpftoirovwcyanitdggeesxb3sl2cxyxzjwrzcvszzry65r7ts.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:1" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:1" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 7][7, 7, 1]cuda:1" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 7, 7][49, 49, 7, 1]cuda:1" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 7][7, 7, 1]cuda:1" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 7, 7][49, 49, 7, 1]cuda:1" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 7][7, 7, 1]cuda:1" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 7, 7][49, 49, 7, 1]cuda:1" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 7 + stride_q_idx_h = 49 + stride_q_idx_n = 7 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_15, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(primals_16, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_17, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(primals_18, (1, 1, 7), (7, 7, 1)) + assert_size_stride(primals_19, (1, 1, 7, 7), (49, 49, 7, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream1) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream1 = get_raw_stream(1) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream1) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 896 + primals_11 = 896 + primals_7 = 7 + primals_8 = 7 + primals_12 = 7 + primals_2 = rand_strided((1, 32, 896, 128), (3670016, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 896, 128), (917504, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 896, 128), (917504, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 7), (7, 7, 1), device='cuda:1', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 7, 7), (49, 49, 7, 1), device='cuda:1', dtype=torch.int32) + getitem = rand_strided((1, 32, 896, 128), (3670016, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 896), (28672, 896, 1), device='cuda:1', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 896, 128), (3670016, 114688, 128, 1), device='cuda:1', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 896), (28672, 896, 1), device='cuda:1', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/sc/csc67q546j7muhvoccled2oa4u3qywv2ugybi3n244etkkzozhm2.py b/progress/SpecForge/cache/compiled_kernels/sc/csc67q546j7muhvoccled2oa4u3qywv2ugybi3n244etkkzozhm2.py new file mode 100644 index 0000000000000000000000000000000000000000..c13ed38376308c0466ef6e5089fd4dacb2714f7e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sc/csc67q546j7muhvoccled2oa4u3qywv2ugybi3n244etkkzozhm2.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4521984, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1130496, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1130496, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1104 + ZKV = 1 + KV_LEN = 1104 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 141312*idx_hq + 4521984*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py b/progress/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py new file mode 100644 index 0000000000000000000000000000000000000000..ead4ca5b408561e8afe773976360f347f69d2665 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sc/cscnwzzlpcjsqvndc4tlfwact2ecwdimqtwu2vya2cnto5t7c7pi.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/sc/e0eadf73379a6d7fd37d51b792317ea0c7e01525502a6363ef5eb0c0327fc37f.best_config b/progress/SpecForge/cache/compiled_kernels/sc/e0eadf73379a6d7fd37d51b792317ea0c7e01525502a6363ef5eb0c0327fc37f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..40730bb99578020826990ee2fa0abea87202f3e0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sc/e0eadf73379a6d7fd37d51b792317ea0c7e01525502a6363ef5eb0c0327fc37f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 47, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sf/95529215d128928370fe0075f582276af086d6a127197fb21e4886172f267e2a.best_config b/progress/SpecForge/cache/compiled_kernels/sf/95529215d128928370fe0075f582276af086d6a127197fb21e4886172f267e2a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..5add235a87151b2559c5bc48bcda17ee438a4661 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sf/95529215d128928370fe0075f582276af086d6a127197fb21e4886172f267e2a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "2685c2d349c32243d4ee216505dfdf1e257d04d8316595ed69d4ca3499146788", "found_by_coordesc": false, "time_taken_ms": 48, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sf/csf7dseluvszsdrkislri4gzoxq6kv6nf7kvg2pyh3fuptdw4ejz.py b/progress/SpecForge/cache/compiled_kernels/sf/csf7dseluvszsdrkislri4gzoxq6kv6nf7kvg2pyh3fuptdw4ejz.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0f086cbf30524f5490c7355333c50459639e29 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sf/csf7dseluvszsdrkislri4gzoxq6kv6nf7kvg2pyh3fuptdw4ejz.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 11 + stride_q_idx_h = 121 + stride_q_idx_n = 11 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py b/progress/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py new file mode 100644 index 0000000000000000000000000000000000000000..9d35058338bacd4f34cea0224649b563e28330e6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/sf/csfiiaz4rrgyy7cuqppbaaifh36xb62bfukrdavoejklnec4xgl2.py b/progress/SpecForge/cache/compiled_kernels/sf/csfiiaz4rrgyy7cuqppbaaifh36xb62bfukrdavoejklnec4xgl2.py new file mode 100644 index 0000000000000000000000000000000000000000..751b1faa77c0f76ea0c1fd334e84587399c1840d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sf/csfiiaz4rrgyy7cuqppbaaifh36xb62bfukrdavoejklnec4xgl2.py @@ -0,0 +1,715 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gc/cgceywn6audeh4gx6twoifhmr6vj6e3pjftmmpzgsfjri6cwdxdv.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:0" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:0" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:0" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/g4/cg4y3bgvtlp62cdorpd6bc2yltyucyl4uv5uslettv26fielsmud.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg9_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg10_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg11_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg12_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg13_1, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(arg14_1, (1, 1, 9), (9, 9, 1)) + assert_size_stride(arg15_1, (1, 1, 9, 9), (81, 81, 9, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream0) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1120 + arg1_1 = rand_strided((1, 32, 1120, 128), (4587520, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 1120 + arg3_1 = rand_strided((1, 8, 1120, 128), (1146880, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg4_1 = 1120 + arg5_1 = rand_strided((1, 8, 1120, 128), (1146880, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg7_1 = 1120 + arg8_1 = 1120 + arg9_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/sh/0c7ce420e763b642c6279e9424cd4f7dfcd4bb2f2e5a46d22c9ba9e1aac0e943.best_config b/progress/SpecForge/cache/compiled_kernels/sh/0c7ce420e763b642c6279e9424cd4f7dfcd4bb2f2e5a46d22c9ba9e1aac0e943.best_config new file mode 100644 index 0000000000000000000000000000000000000000..77b32a09d39c99cc42575e211ae8fe06cc0fdf56 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sh/0c7ce420e763b642c6279e9424cd4f7dfcd4bb2f2e5a46d22c9ba9e1aac0e943.best_config @@ -0,0 +1 @@ +{"XBLOCK": 64, "R0_BLOCK": 64, "num_warps": 16, "num_stages": 1, "configs_hash": "38d69fe34c260c04bb09639cda54409e946793e79d4d066df3a182582fff66ed", "found_by_coordesc": false, "time_taken_ms": 41, "triton_cache_hash": "GBIQTIXLLLI56EMJONBW74RZJ42E6PTSU5N7LA23N4VBEBKK3HNQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py b/progress/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py new file mode 100644 index 0000000000000000000000000000000000000000..67dc44d59405aec8ba39c3442f3f5e545c8b66aa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sh/csh76hcjkj7bc6jvydzdmaapo6vnfxlvc3xvqexzngu63td4qnjk.py @@ -0,0 +1,51 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/sj/csjgoierbyb7y37xze3raek6tsz2nam2dgfupz3aqsd5fa3pyzto.py b/progress/SpecForge/cache/compiled_kernels/sj/csjgoierbyb7y37xze3raek6tsz2nam2dgfupz3aqsd5fa3pyzto.py new file mode 100644 index 0000000000000000000000000000000000000000..05e6d47e5abdabbc9c9e3d2661a34d79d7eb8a36 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sj/csjgoierbyb7y37xze3raek6tsz2nam2dgfupz3aqsd5fa3pyzto.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/sj/csjrlug4jqqsdp52xnorr2wjqndi3dvxf33umzbjv43vkqekggiz.py b/progress/SpecForge/cache/compiled_kernels/sj/csjrlug4jqqsdp52xnorr2wjqndi3dvxf33umzbjv43vkqekggiz.py new file mode 100644 index 0000000000000000000000000000000000000000..9c896ceed64f2513ba76c35416d66d84d3129bbe --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sj/csjrlug4jqqsdp52xnorr2wjqndi3dvxf33umzbjv43vkqekggiz.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sm/csmz2k3edklza74y46mapw3rjeds7hufvjfwpaldv36ivikgh5kl.py b/progress/SpecForge/cache/compiled_kernels/sm/csmz2k3edklza74y46mapw3rjeds7hufvjfwpaldv36ivikgh5kl.py new file mode 100644 index 0000000000000000000000000000000000000000..6861a6bdd1e250bb6dd1a59c6d8261cdf51bc797 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sm/csmz2k3edklza74y46mapw3rjeds7hufvjfwpaldv36ivikgh5kl.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/sv/csvjnfh5e5wghck6hiap6kf2di4dsachfvad46xenh7oekre4dxa.py b/progress/SpecForge/cache/compiled_kernels/sv/csvjnfh5e5wghck6hiap6kf2di4dsachfvad46xenh7oekre4dxa.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd01506c5cbbfb9600ead948d58145a6b0bd6e9 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sv/csvjnfh5e5wghck6hiap6kf2di4dsachfvad46xenh7oekre4dxa.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sz/11dc52b3ed5591712ce9ce6ab05944fddb63780a31a379a1e4315bbb9b4cd6eb.best_config b/progress/SpecForge/cache/compiled_kernels/sz/11dc52b3ed5591712ce9ce6ab05944fddb63780a31a379a1e4315bbb9b4cd6eb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d441143dba9e89d6d81a5c55ecd4b33917e9dc45 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/11dc52b3ed5591712ce9ce6ab05944fddb63780a31a379a1e4315bbb9b4cd6eb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sz/5cf92187c9236c4a1c239035b6bc29024a962f455823b8aa63c75565caa6ebc0.best_config b/progress/SpecForge/cache/compiled_kernels/sz/5cf92187c9236c4a1c239035b6bc29024a962f455823b8aa63c75565caa6ebc0.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/5cf92187c9236c4a1c239035b6bc29024a962f455823b8aa63c75565caa6ebc0.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sz/7f43e772e77219e377f2b3e1b6426eee8d455495d4aabb298e4fd8708fbefbcb.best_config b/progress/SpecForge/cache/compiled_kernels/sz/7f43e772e77219e377f2b3e1b6426eee8d455495d4aabb298e4fd8708fbefbcb.best_config new file mode 100644 index 0000000000000000000000000000000000000000..86180d8e7f10e5ba4d8d3e8abd8ff2ae83822ad0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/7f43e772e77219e377f2b3e1b6426eee8d455495d4aabb298e4fd8708fbefbcb.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "I32ECYH5I75REV4TRBMWHSFZKPTUF2XC4JIHGP5FR3TBMM6U2PKQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/sz/csz4oyqya4trc5ll55zjge2jpbfnpn34ci462kngl55bi66jokcx.py b/progress/SpecForge/cache/compiled_kernels/sz/csz4oyqya4trc5ll55zjge2jpbfnpn34ci462kngl55bi66jokcx.py new file mode 100644 index 0000000000000000000000000000000000000000..6c55964fc2c141201af703d32441842bade8b51c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/csz4oyqya4trc5ll55zjge2jpbfnpn34ci462kngl55bi66jokcx.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/sz/cszgqakbiplu4kezgnolrvjqw5rvgv44pau44u5c3nelyopoju4t.py b/progress/SpecForge/cache/compiled_kernels/sz/cszgqakbiplu4kezgnolrvjqw5rvgv44pau44u5c3nelyopoju4t.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd8303e7270547131328ec53df712813ecf97eb --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/cszgqakbiplu4kezgnolrvjqw5rvgv44pau44u5c3nelyopoju4t.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/sz/csztv2zsqb2ktn2e3qzo2qhm3igoajqarpmeyv4b47a7lltch4yf.py b/progress/SpecForge/cache/compiled_kernels/sz/csztv2zsqb2ktn2e3qzo2qhm3igoajqarpmeyv4b47a7lltch4yf.py new file mode 100644 index 0000000000000000000000000000000000000000..2157efbd5097052b7d264ae74ac67244bc16a246 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/sz/csztv2zsqb2ktn2e3qzo2qhm3igoajqarpmeyv4b47a7lltch4yf.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) diff --git a/progress/SpecForge/cache/compiled_kernels/t2/ct2ur7xvbkwpboszug7nn3jsmf2mlhxx7jrqovhxjmtr3sp6hrvw.py b/progress/SpecForge/cache/compiled_kernels/t2/ct2ur7xvbkwpboszug7nn3jsmf2mlhxx7jrqovhxjmtr3sp6hrvw.py new file mode 100644 index 0000000000000000000000000000000000000000..806e87bd20253a3fbb53a56311c86ba546295856 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/t2/ct2ur7xvbkwpboszug7nn3jsmf2mlhxx7jrqovhxjmtr3sp6hrvw.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4521984, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1130496, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1130496, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1104 + ZKV = 1 + KV_LEN = 1104 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 141312*idx_hq + 4521984*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/tb/ctbggfksobt2p2yz73uzwirpve64kqsqoftqfrpk7mxp7o74fnxh.py b/progress/SpecForge/cache/compiled_kernels/tb/ctbggfksobt2p2yz73uzwirpve64kqsqoftqfrpk7mxp7o74fnxh.py new file mode 100644 index 0000000000000000000000000000000000000000..23cbe725e93123a61d54b596511689ee71baad1f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tb/ctbggfksobt2p2yz73uzwirpve64kqsqoftqfrpk7mxp7o74fnxh.py @@ -0,0 +1,715 @@ +# AOT ID: ['4_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/lw/clwi4dqbrij6ek7edfo5fy635sclo6xa5a5yx4tuirr5mullbmva.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7a/c7a2brsshxp6zz4foe62t5ivwbd2dwr6ytjbhxp22vq2evdotx5z.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream4) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream4) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/tb/ctbxuobrle5iwzw7zg5iml5bl4u3vjpvqet2p3zbsg2cm3om2xdq.py b/progress/SpecForge/cache/compiled_kernels/tb/ctbxuobrle5iwzw7zg5iml5bl4u3vjpvqet2p3zbsg2cm3om2xdq.py new file mode 100644 index 0000000000000000000000000000000000000000..304ab22b77f541bcc29c23f14ed1ea5cdd780f46 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tb/ctbxuobrle5iwzw7zg5iml5bl4u3vjpvqet2p3zbsg2cm3om2xdq.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ka/cka7xx5r5hjbxmpyy422mg5htwwxacunbdxsuwsvx5i3hsleezxl.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csf7dseluvszsdrkislri4gzoxq6kv6nf7kvg2pyh3fuptdw4ejz.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:2" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:2" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:2" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:2" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:2" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:2" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 11 + stride_q_idx_h = 121 + stride_q_idx_n = 11 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_15, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_16, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_17, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_18, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_19, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream2) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1312 + primals_11 = 1312 + primals_7 = 11 + primals_8 = 11 + primals_12 = 11 + primals_2 = rand_strided((1, 32, 1312, 128), (5373952, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 1312, 128), (1343488, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 1312, 128), (1343488, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:2', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((1, 32, 1312, 128), (5373952, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1312), (41984, 1312, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1312, 128), (5373952, 167936, 128, 1), device='cuda:2', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1312), (41984, 1312, 1), device='cuda:2', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/td/3723c3c9db294ab3cc4a9f826a5ce940dbc5086714c4985bc8314ccef68e0bea.best_config b/progress/SpecForge/cache/compiled_kernels/td/3723c3c9db294ab3cc4a9f826a5ce940dbc5086714c4985bc8314ccef68e0bea.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d2d36089fa3ca5d31bfaf2bb97f1693f397c2328 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/td/3723c3c9db294ab3cc4a9f826a5ce940dbc5086714c4985bc8314ccef68e0bea.best_config @@ -0,0 +1 @@ +{"XBLOCK": 4, "R0_BLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 47, "triton_cache_hash": "KQBNSXG43774AG7AZSZ4NGC2ACGCHAORJKIGJ36XNGPJ6EBTHLFA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/td/ctdn3hijvlywgmamgcmc3dewg5hrtelyrdorwcy5pxw43gibtgkp.py b/progress/SpecForge/cache/compiled_kernels/td/ctdn3hijvlywgmamgcmc3dewg5hrtelyrdorwcy5pxw43gibtgkp.py new file mode 100644 index 0000000000000000000000000000000000000000..30c40cb37b75f7223571b4c7c5deb9084999d39f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/td/ctdn3hijvlywgmamgcmc3dewg5hrtelyrdorwcy5pxw43gibtgkp.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 724992, 'r0_': 30932992}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 60416 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 1888) + x1 = xindex // 1888 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/tg/b03f5af179fd205d7f111df27677ba72aeeb452ffb30bef6aa44125d4a9abe56.best_config b/progress/SpecForge/cache/compiled_kernels/tg/b03f5af179fd205d7f111df27677ba72aeeb452ffb30bef6aa44125d4a9abe56.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tg/b03f5af179fd205d7f111df27677ba72aeeb452ffb30bef6aa44125d4a9abe56.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/tl/ctlpu6dtplhnxq223amgvjrfwcqysnrjubory4o2ult7mshpvgyw.py b/progress/SpecForge/cache/compiled_kernels/tl/ctlpu6dtplhnxq223amgvjrfwcqysnrjubory4o2ult7mshpvgyw.py new file mode 100644 index 0000000000000000000000000000000000000000..791a73876e0ac1d4b3e088d07fdec1a15c7c8b19 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tl/ctlpu6dtplhnxq223amgvjrfwcqysnrjubory4o2ult7mshpvgyw.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/tm/ctm477d6lh6tb42ydiijd5zvrgpcrn5box5it3krvzamvoxevp7f.py b/progress/SpecForge/cache/compiled_kernels/tm/ctm477d6lh6tb42ydiijd5zvrgpcrn5box5it3krvzamvoxevp7f.py new file mode 100644 index 0000000000000000000000000000000000000000..c48d8615e9491edb5ce2aa7ad2fba8cd0e8f1d7d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tm/ctm477d6lh6tb42ydiijd5zvrgpcrn5box5it3krvzamvoxevp7f.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/tq/ctqcepmtvjaox3vwfoemyiaff4myknwhbejbs2txsxhov5lyloip.py b/progress/SpecForge/cache/compiled_kernels/tq/ctqcepmtvjaox3vwfoemyiaff4myknwhbejbs2txsxhov5lyloip.py new file mode 100644 index 0000000000000000000000000000000000000000..72ede8826df979166409ff76ac16c4fc9e6cb283 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tq/ctqcepmtvjaox3vwfoemyiaff4myknwhbejbs2txsxhov5lyloip.py @@ -0,0 +1,721 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/75/c754kjoskxv3cqqkmnl5v5i6bmsuwcudcec4lf36webo5fkp64ty.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %arg12_1 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=arg12_1] +# %arg8_1 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=arg8_1] +# %arg13_1 : Tensor "i32[1, 1, 4][4, 4, 1]cuda:5" = PlaceHolder[target=arg13_1] +# %arg14_1 : Tensor "i32[1, 1, 4, 4][16, 16, 4, 1]cuda:5" = PlaceHolder[target=arg14_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg9_1, %arg10_1, %arg12_1, %arg8_1, %arg13_1, %arg14_1, %arg15_1, %arg16_1, %arg17_1, %arg18_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tg/ctgd6c2ixeeuw4svuupxv4sfwufz5ikl6mu65w37fz23qyt4ckfb.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16384}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s22 = arg6_1 + s72 = arg7_1 + s37 = arg9_1 + s71 = arg10_1 + s99 = arg11_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg8_1, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(arg12_1, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(arg13_1, (1, 1, 4), (4, 4, 1)) + assert_size_stride(arg14_1, (1, 1, 4, 4), (16, 16, 4, 1)) + assert_size_stride(arg15_1, (1, 1, 4), (4, 4, 1)) + assert_size_stride(arg16_1, (1, 1, 4, 4), (16, 16, 4, 1)) + assert_size_stride(arg17_1, (1, 1, 4), (4, 4, 1)) + assert_size_stride(arg18_1, (1, 1, 4, 4), (16, 16, 4, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg12_1, arg8_1, arg13_1, arg14_1, buf2, s37, s0, s43, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream5) + del arg12_1 + del arg13_1 + del arg14_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg8_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream5) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 464 + arg1_1 = rand_strided((1, 32, 464, 128), (1900544, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = 464 + arg3_1 = rand_strided((1, 8, 464, 128), (475136, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg4_1 = 464 + arg5_1 = rand_strided((1, 8, 464, 128), (475136, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg6_1 = 4 + arg7_1 = 4 + arg8_1 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + arg9_1 = 464 + arg10_1 = 464 + arg11_1 = 4 + arg12_1 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + arg16_1 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + arg17_1 = rand_strided((1, 1, 4), (4, 4, 1), device='cuda:5', dtype=torch.int32) + arg18_1 = rand_strided((1, 1, 4, 4), (16, 16, 4, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ty/32d4f14fc58317f5988e8c3684cea02c8b49eb2508e2707c2187261a2a5f665b.best_config b/progress/SpecForge/cache/compiled_kernels/ty/32d4f14fc58317f5988e8c3684cea02c8b49eb2508e2707c2187261a2a5f665b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f3cda7a0e24fbe317fea6a7e3c8ec4e92f936acc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ty/32d4f14fc58317f5988e8c3684cea02c8b49eb2508e2707c2187261a2a5f665b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 24, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ty/ctye46k4gahd2e44muti52bzvk5ppd4o67yh77qc74mlf4iqkey7.py b/progress/SpecForge/cache/compiled_kernels/ty/ctye46k4gahd2e44muti52bzvk5ppd4o67yh77qc74mlf4iqkey7.py new file mode 100644 index 0000000000000000000000000000000000000000..ec177b5b669176a8b5c9cddd2d94e3089b98a076 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ty/ctye46k4gahd2e44muti52bzvk5ppd4o67yh77qc74mlf4iqkey7.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/tz/ctzckrbdrxl2vvmkrparzusj5jevllalx5mgn4xod5rvzfn4pkyf.py b/progress/SpecForge/cache/compiled_kernels/tz/ctzckrbdrxl2vvmkrparzusj5jevllalx5mgn4xod5rvzfn4pkyf.py new file mode 100644 index 0000000000000000000000000000000000000000..91fe86e14df70510ff89301ec95856974727deb3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/tz/ctzckrbdrxl2vvmkrparzusj5jevllalx5mgn4xod5rvzfn4pkyf.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/u2/40cc2a7d5c8ad6f9a23d9d5c3a69d1f155b85e07148a08a78cc51799790b72a5.best_config b/progress/SpecForge/cache/compiled_kernels/u2/40cc2a7d5c8ad6f9a23d9d5c3a69d1f155b85e07148a08a78cc51799790b72a5.best_config new file mode 100644 index 0000000000000000000000000000000000000000..38d3904992118dfdfceb98c87234684feb5244a1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u2/40cc2a7d5c8ad6f9a23d9d5c3a69d1f155b85e07148a08a78cc51799790b72a5.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 58, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/u2/cu2azlmlb5dv6qaawuhbij2p4nat3p7qd3slteyjah2xcdvoeanh.py b/progress/SpecForge/cache/compiled_kernels/u2/cu2azlmlb5dv6qaawuhbij2p4nat3p7qd3slteyjah2xcdvoeanh.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe81e8beb232e423c20c33b4ab5c7df50d986fe --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u2/cu2azlmlb5dv6qaawuhbij2p4nat3p7qd3slteyjah2xcdvoeanh.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/u2/cu2syvrflgorfkp2lmrr77y5zdy2o6j3zispgwqzaeqv2gbu5mvx.py b/progress/SpecForge/cache/compiled_kernels/u2/cu2syvrflgorfkp2lmrr77y5zdy2o6j3zispgwqzaeqv2gbu5mvx.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b9479b8f8e37f85ea863ce15cd8504783f572e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u2/cu2syvrflgorfkp2lmrr77y5zdy2o6j3zispgwqzaeqv2gbu5mvx.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/u7/cu76pkedlam7gs4ku5zq6m5zlh6t3rdjzbnfs5iu7hyafsyiq73r.py b/progress/SpecForge/cache/compiled_kernels/u7/cu76pkedlam7gs4ku5zq6m5zlh6t3rdjzbnfs5iu7hyafsyiq73r.py new file mode 100644 index 0000000000000000000000000000000000000000..a6dca32c77581eaa320b4f97c8ac41254f1e31da --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u7/cu76pkedlam7gs4ku5zq6m5zlh6t3rdjzbnfs5iu7hyafsyiq73r.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/u7/cu7o2hdud4yjwjkdwdhjnkmaco2z7ec54oaeglpmphgp6uwrjkn4.py b/progress/SpecForge/cache/compiled_kernels/u7/cu7o2hdud4yjwjkdwdhjnkmaco2z7ec54oaeglpmphgp6uwrjkn4.py new file mode 100644 index 0000000000000000000000000000000000000000..c10543644727b324d0834b2df4cd1fdec7d72453 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u7/cu7o2hdud4yjwjkdwdhjnkmaco2z7ec54oaeglpmphgp6uwrjkn4.py @@ -0,0 +1,715 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ew/cewn2ayqifa2nhwgosykqxi6glpfgcflbcihzqk4harxwe5hv5a3.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:7" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:7" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:7" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:7" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:7" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hc/chcapoqvma4vr7nyxct77gqu7ncybp4lj37z3ip24ltc4pdvsp6v.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg9_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg10_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg11_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg12_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg13_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg14_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg15_1, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream7) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream7) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 688 + arg1_1 = rand_strided((1, 32, 688, 128), (2818048, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + arg2_1 = 688 + arg3_1 = rand_strided((1, 8, 688, 128), (704512, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg4_1 = 688 + arg5_1 = rand_strided((1, 8, 688, 128), (704512, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg7_1 = 688 + arg8_1 = 688 + arg9_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:7', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/u7/cu7sdfz5zirqjkk3xofseufd2v3mtab5c6s5hsbn3kzvmfsbndy6.py b/progress/SpecForge/cache/compiled_kernels/u7/cu7sdfz5zirqjkk3xofseufd2v3mtab5c6s5hsbn3kzvmfsbndy6.py new file mode 100644 index 0000000000000000000000000000000000000000..90cc530d2894dd8d8ad0d76860c17ab527a66aba --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/u7/cu7sdfz5zirqjkk3xofseufd2v3mtab5c6s5hsbn3kzvmfsbndy6.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/si/csiqxh5yfqziyhs4zfuiszx7ckahzxji737hjp4j5cz7dsebx4cg.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 768, 128][3145728, 128, 4096, 1]cuda:5" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 768, 128][786432, 128, 1024, 1]cuda:5" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:5" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:5" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:5" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:5" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (768, 768, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 98304*idx_hq + 3145728*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sz/csztv2zsqb2ktn2e3qzo2qhm3igoajqarpmeyv4b47a7lltch4yf.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 768][24576, 768, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 196608}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 24576 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = tl.full([XBLOCK], True, tl.int1) + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 768, 128), (3145728, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 768, 128), (786432, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg4_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg5_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg6_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg7_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg8_1, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(arg9_1, (1, 1, 6), (6, 6, 1)) + assert_size_stride(arg10_1, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 768), (24576, 768, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 768, 128), (3145728, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 6, 1, 32, stream=stream5) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, 24576, stream=stream5) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 768, 128), (3145728, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 768, 128), (786432, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:5', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ua/1411d6e2ff5fd595aed809997875cb427a8d1668a5e7aff8a2dcc88f0ccf9973.best_config b/progress/SpecForge/cache/compiled_kernels/ua/1411d6e2ff5fd595aed809997875cb427a8d1668a5e7aff8a2dcc88f0ccf9973.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d131a13c7786a0afd9ad24a5628260bf7969ba42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ua/1411d6e2ff5fd595aed809997875cb427a8d1668a5e7aff8a2dcc88f0ccf9973.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ua/cua7yz7dxwt2i3zndnhksibxtopafloy2iioog52rg7qunj23nn7.py b/progress/SpecForge/cache/compiled_kernels/ua/cua7yz7dxwt2i3zndnhksibxtopafloy2iioog52rg7qunj23nn7.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ca4a6307304ac66a5849c976132a3c96d75ad8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ua/cua7yz7dxwt2i3zndnhksibxtopafloy2iioog52rg7qunj23nn7.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/ua/cuak4xnl32sznllbz7nmpi33f77bskrx4jt545ttz3xmdzenvfhr.py b/progress/SpecForge/cache/compiled_kernels/ua/cuak4xnl32sznllbz7nmpi33f77bskrx4jt545ttz3xmdzenvfhr.py new file mode 100644 index 0000000000000000000000000000000000000000..9e779f5fa9f481fc0b370b3cb141fbecabfe9648 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ua/cuak4xnl32sznllbz7nmpi33f77bskrx4jt545ttz3xmdzenvfhr.py @@ -0,0 +1,879 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hn/chnb6aeoibwfdajl2bc2brixhsgsiehiaisuunjcavjkmuxf45kd.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:1" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:1" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:1" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/wp/cwpp5u6vrzuvzfyrgcsc3peqmsqtexvpdosjke3rbbnyl7sl3zax.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iv/civspn3mu5mvev3iue5ugsniof2fkp4gxubekpwhysago5i7xsc7.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:1" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:1" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:1" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:1"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream1 = get_raw_stream(1) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream1) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream1 = get_raw_stream(1) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream1) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream1 = get_raw_stream(1) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream1) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 112 + arg1_1 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + arg2_1 = 112 + arg3_1 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg4_1 = 112 + arg5_1 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg7_1 = 112 + arg8_1 = 112 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:1', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:1', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ub/cubpmxjy3cjppf6cr5nwsk75o5nvlhrlaoeys62x3sksiyf5xonc.py b/progress/SpecForge/cache/compiled_kernels/ub/cubpmxjy3cjppf6cr5nwsk75o5nvlhrlaoeys62x3sksiyf5xonc.py new file mode 100644 index 0000000000000000000000000000000000000000..fcef408fc478f55cc06ee39e0e224ea0a1dcbc1a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ub/cubpmxjy3cjppf6cr5nwsk75o5nvlhrlaoeys62x3sksiyf5xonc.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sm/csmz2k3edklza74y46mapw3rjeds7hufvjfwpaldv36ivikgh5kl.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:5" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:5" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/qt/cqtdareqz6tivounhz5hfeqyiotaxvjxc2v3hngtznpiz5nqvksa.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_15, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_16, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_17, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_18, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_19, (1, 1, 9, 9), (81, 81, 9, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream5) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream5) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1104 + primals_2 = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = 1104 + primals_4 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = 1104 + primals_6 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = 9 + primals_8 = 9 + primals_9 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_10 = 1104 + primals_11 = 1104 + primals_12 = 9 + primals_13 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ub/cubr5qxynzjxlx2y5h2i5skpq64ikutw4flhf52gnghyboe6uocs.py b/progress/SpecForge/cache/compiled_kernels/ub/cubr5qxynzjxlx2y5h2i5skpq64ikutw4flhf52gnghyboe6uocs.py new file mode 100644 index 0000000000000000000000000000000000000000..51b624da8089cb9c697c5b02cd01d78c435c5efd --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ub/cubr5qxynzjxlx2y5h2i5skpq64ikutw4flhf52gnghyboe6uocs.py @@ -0,0 +1,1018 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 976, 128][3997696, 124928, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (976, 976, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 374784, 'r0_': 15990784}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 31232 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 976) + x1 = xindex // 976 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/om/comxwt7qmz7r76sshdlb7eiymhymj24rvyjhgcvd4d54rii5xcxm.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 976, 128][3997696, 124928, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_8 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (976, 976, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 999424, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3997696, 124928, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3997696, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 8 + stride_q_idx_h = 64 + stride_q_idx_n = 8 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 124928*off_hkv + 999424*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_5, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_6, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_7, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_8, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_9, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_10, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_11, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(getitem, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 976), (31232, 976, 1)) + assert_size_stride(tangents_1, (1, 32, 976, 128), (3997696, 124928, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 976), (31232, 976, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 31232, 128, stream=stream7) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 976, 128), (999424, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 976, 128), (999424, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 40, 1, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 976), (31232, 976, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 976, 128), (3997696, 124928, 128, 1), device='cuda:7', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 976), (31232, 976, 1), device='cuda:7', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ub/cubrsuz4vfy4ff4xn7qrblxgooyak2quzvta3hr3vztgm3yeazir.py b/progress/SpecForge/cache/compiled_kernels/ub/cubrsuz4vfy4ff4xn7qrblxgooyak2quzvta3hr3vztgm3yeazir.py new file mode 100644 index 0000000000000000000000000000000000000000..8de5571d824241039eb668d2904ac499667d920a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ub/cubrsuz4vfy4ff4xn7qrblxgooyak2quzvta3hr3vztgm3yeazir.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ud/cudpkvrfsavs5fzglpxru5ms2bwn5cuyexc7hataogv2zicg4bda.py b/progress/SpecForge/cache/compiled_kernels/ud/cudpkvrfsavs5fzglpxru5ms2bwn5cuyexc7hataogv2zicg4bda.py new file mode 100644 index 0000000000000000000000000000000000000000..7928f79bb258a99d3db3863d82727bff59c9c941 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ud/cudpkvrfsavs5fzglpxru5ms2bwn5cuyexc7hataogv2zicg4bda.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 3145728, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 786432, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 786432, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 768 + ZKV = 1 + KV_LEN = 768 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 6 + stride_kv_idx_h = 36 + stride_kv_idx_m = 6 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 98304*idx_hq + 3145728*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ue/cueevzv7vmofb7xgliazywafcctru3ytzoxylqvshvpe6nweecz6.py b/progress/SpecForge/cache/compiled_kernels/ue/cueevzv7vmofb7xgliazywafcctru3ytzoxylqvshvpe6nweecz6.py new file mode 100644 index 0000000000000000000000000000000000000000..44df908b628d40625dec8b6f309607c486276ce2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ue/cueevzv7vmofb7xgliazywafcctru3ytzoxylqvshvpe6nweecz6.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ug/cugsdz6ngnkagg3e7qtsoxgk3tpt6tu3s7opnd7ctcm5xsdh4xit.py b/progress/SpecForge/cache/compiled_kernels/ug/cugsdz6ngnkagg3e7qtsoxgk3tpt6tu3s7opnd7ctcm5xsdh4xit.py new file mode 100644 index 0000000000000000000000000000000000000000..1b794386c4a4354e37cf91375c98e0f4e1780b8a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ug/cugsdz6ngnkagg3e7qtsoxgk3tpt6tu3s7opnd7ctcm5xsdh4xit.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4980736, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1245184, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1245184, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1216 + ZKV = 1 + KV_LEN = 1216 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 10 + stride_kv_idx_h = 100 + stride_kv_idx_m = 10 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 155648*idx_hq + 4980736*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/uj/49010af2885e87d3ab761092d5518a27fd02638fe60871876851c35f7f6eb839.best_config b/progress/SpecForge/cache/compiled_kernels/uj/49010af2885e87d3ab761092d5518a27fd02638fe60871876851c35f7f6eb839.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c090b439cd98894f50ef817b2f3b529d58dbee1e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uj/49010af2885e87d3ab761092d5518a27fd02638fe60871876851c35f7f6eb839.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 89, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/uj/cujc5sujk2k72m34azlr6o6up4djfa3r72g4pqs7juagmri5ed7p.py b/progress/SpecForge/cache/compiled_kernels/uj/cujc5sujk2k72m34azlr6o6up4djfa3r72g4pqs7juagmri5ed7p.py new file mode 100644 index 0000000000000000000000000000000000000000..2582cc84a438223da6d26e0e8c52e6b78b51ad86 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uj/cujc5sujk2k72m34azlr6o6up4djfa3r72g4pqs7juagmri5ed7p.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/uj/cujn5acctg5nu6gashmhyky6ld5c64hfh4c3d2k6ma6sfsg7yo5x.py b/progress/SpecForge/cache/compiled_kernels/uj/cujn5acctg5nu6gashmhyky6ld5c64hfh4c3d2k6ma6sfsg7yo5x.py new file mode 100644 index 0000000000000000000000000000000000000000..adf07edbbc8faa0dc6d82c00ac3b9a03ef792118 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uj/cujn5acctg5nu6gashmhyky6ld5c64hfh4c3d2k6ma6sfsg7yo5x.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 262144, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/ul/culcv2qdlct4yajrlsad7gaww5dmis5fk73mjo4wp6zhkcxzoc26.py b/progress/SpecForge/cache/compiled_kernels/ul/culcv2qdlct4yajrlsad7gaww5dmis5fk73mjo4wp6zhkcxzoc26.py new file mode 100644 index 0000000000000000000000000000000000000000..3088c153bd2f83e694ceb6c8bb7cfa6253f5b4aa --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ul/culcv2qdlct4yajrlsad7gaww5dmis5fk73mjo4wp6zhkcxzoc26.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/uo/cuobxprbubloh4bvbpalx7xyphdph65h3ovmdtjcxienk75ylwk5.py b/progress/SpecForge/cache/compiled_kernels/uo/cuobxprbubloh4bvbpalx7xyphdph65h3ovmdtjcxienk75ylwk5.py new file mode 100644 index 0000000000000000000000000000000000000000..5b84a55bc42eefe34a6a324d1f93eae302cd0e86 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uo/cuobxprbubloh4bvbpalx7xyphdph65h3ovmdtjcxienk75ylwk5.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jy/cjykix45b4axvnwiqzx22lsjfok37qykfnnjytv5a3mbo4d27i6o.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hb/chb3u4duw7nywxkok7ma3jaz7y56uys3waffebamb6kmmcygyno5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:5" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:5" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:5" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:5" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:5" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:5" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 9 + stride_q_idx_h = 81 + stride_q_idx_n = 9 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_15, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_16, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_17, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_18, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_19, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream5 = get_raw_stream(5) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream5) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream5 = get_raw_stream(5) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1104 + primals_11 = 1104 + primals_7 = 9 + primals_8 = 9 + primals_12 = 9 + primals_2 = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:5', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:5', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1104), (35328, 1104, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1104, 128), (4521984, 141312, 128, 1), device='cuda:5', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1104), (35328, 1104, 1), device='cuda:5', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/up/cuphnjzv4azjtj4fg6iv7i7iunp77atpptkpjdijpuwisop7xsjz.py b/progress/SpecForge/cache/compiled_kernels/up/cuphnjzv4azjtj4fg6iv7i7iunp77atpptkpjdijpuwisop7xsjz.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa53401438ef66dafbb547805c2e4a600d06617 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/up/cuphnjzv4azjtj4fg6iv7i7iunp77atpptkpjdijpuwisop7xsjz.py @@ -0,0 +1,1018 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zx/czxnzxgrxghuyq6lg4yqhvr53wk5msdvvzbzikyl5wfl7qctkl5l.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 1104, 128][4521984, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 1104, 128][4521984, 141312, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 1104][36864, 1152, 1]cuda:3" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1104, 1104, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 423936, 'r0_': 18087936}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 35328 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 1104) + x1 = xindex // 1104 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/q5/cq57polmfnqgdtojqa4xuffcmv4lgydnwkcwqvijbd3i74sdlt6l.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1104, 128][4521984, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 1104, 128][4521984, 141312, 128, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 1104, 128][4521984, 128, 4096, 1]cuda:3" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 1104, 128][1130496, 128, 1024, 1]cuda:3" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_8 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 9][9, 9, 1]cuda:3" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 9, 9][81, 81, 9, 1]cuda:3" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 1104][35328, 1104, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (1104, 1104, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4521984, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1130496, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1130496, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4521984, 141312, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4521984, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1130496, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 1104 + ZKV = 1 + KV_LEN = 1104 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 9 + stride_kv_idx_h = 81 + stride_kv_idx_m = 9 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 9 + stride_q_idx_h = 81 + stride_q_idx_n = 9 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 141312*off_hkv + 1130496*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1104 + KV_LEN = 1104 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 1104 + KV_LEN = 1104 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1104, 128), (4521984, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1104, 128), (1130496, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_5, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_6, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_7, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_8, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_9, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(primals_10, (1, 1, 9), (9, 9, 1)) + assert_size_stride(primals_11, (1, 1, 9, 9), (81, 81, 9, 1)) + assert_size_stride(getitem, (1, 32, 1104, 128), (4521984, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 1104), (35328, 1104, 1)) + assert_size_stride(tangents_1, (1, 32, 1104, 128), (4521984, 141312, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 1104), (35328, 1104, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf1 = empty_strided_cuda((1, 32, 1104), (35328, 1104, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 35328, 128, stream=stream3) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 1104, 128), (4521984, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 1104, 128), (1130496, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 1104, 128), (1130496, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 45, 1, 8, stream=stream3) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1104, 128), (1130496, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 9), (9, 9, 1), device='cuda:3', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 9, 9), (81, 81, 9, 1), device='cuda:3', dtype=torch.int32) + getitem = rand_strided((1, 32, 1104, 128), (4521984, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1104), (35328, 1104, 1), device='cuda:3', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1104, 128), (4521984, 141312, 128, 1), device='cuda:3', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1104), (35328, 1104, 1), device='cuda:3', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ur/curnpvcp3l7zkfuholahbfeoo4lyggobhvdyz6ge7ldxyafxglww.py b/progress/SpecForge/cache/compiled_kernels/ur/curnpvcp3l7zkfuholahbfeoo4lyggobhvdyz6ge7ldxyafxglww.py new file mode 100644 index 0000000000000000000000000000000000000000..597142daad84d19a9d891fb9153e0a39bd6878e3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ur/curnpvcp3l7zkfuholahbfeoo4lyggobhvdyz6ge7ldxyafxglww.py @@ -0,0 +1,1019 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/br/cbrxfgmhm5zh3kefpn33kxgk75xcaq7mfgc2olsxvpava77xzvon.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/dq/cdqxgmzsrx3vj3dcv3psyadqd5xk6gnxzczmdxxbkf2jxhszh75x.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream6) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream6 = get_raw_stream(6) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 112 + primals_9 = 112 + primals_2 = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 112, 128), (114688, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((1, 32, 112, 128), (458752, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 112, 128), (458752, 14336, 128, 1), device='cuda:6', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 112), (3584, 112, 1), device='cuda:6', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/ur/curxtqihdmdxm3ite3toi4y7h6mjg2gaohglhi6jf4hl6amaks76.py b/progress/SpecForge/cache/compiled_kernels/ur/curxtqihdmdxm3ite3toi4y7h6mjg2gaohglhi6jf4hl6amaks76.py new file mode 100644 index 0000000000000000000000000000000000000000..218321ba6b0de987d076fdcef7e4e276c2710ed7 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ur/curxtqihdmdxm3ite3toi4y7h6mjg2gaohglhi6jf4hl6amaks76.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/uu/2a063bfb13ee089555fb897c81e20dc037e3211afdce17f993e996b2d9ff9d57.best_config b/progress/SpecForge/cache/compiled_kernels/uu/2a063bfb13ee089555fb897c81e20dc037e3211afdce17f993e996b2d9ff9d57.best_config new file mode 100644 index 0000000000000000000000000000000000000000..291e362253fa48ebc63fe12d2b1707894ca5e923 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uu/2a063bfb13ee089555fb897c81e20dc037e3211afdce17f993e996b2d9ff9d57.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py b/progress/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py new file mode 100644 index 0000000000000000000000000000000000000000..da208b98d2ccc859f59225d85c00d8ab3ca66a6e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uu/cuu2rr2yygwarlbfvcbucg7erbfsky4wxudbfsdny5wzgxewg4ut.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/uu/cuuotz7u35z3shm5obm5xcxpvcb4wthpvaehs6cyvxob4qcaocry.py b/progress/SpecForge/cache/compiled_kernels/uu/cuuotz7u35z3shm5obm5xcxpvcb4wthpvaehs6cyvxob4qcaocry.py new file mode 100644 index 0000000000000000000000000000000000000000..311662212246ffcfe685094061cefb9b4c2cc3e3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uu/cuuotz7u35z3shm5obm5xcxpvcb4wthpvaehs6cyvxob4qcaocry.py @@ -0,0 +1,879 @@ +# AOT ID: ['3_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/fy/cfyocykti4yhb5mgvyvyw6bnmdpc66bcxv6j5bk5ncbu63xljovz.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg5_1] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:4" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:4" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cz/ccznsfhqqvjcsiqwieot6ti7mbssrwnaisfi52ehkwx66zkk2vuk.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul_9 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/cj/ccjnuksl3elmgei4is5wcep6ywju2v3wqhpxexxbjwglhwyy7fbx.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:4" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:4" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, 8, 32, 1, stream=stream4) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream4) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream4 = get_raw_stream(4) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream4) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 80 + arg1_1 = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = 80 + arg3_1 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg4_1 = 80 + arg5_1 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg7_1 = 80 + arg8_1 = 80 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:4', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/uu/cuuphv47i4uyclwboo4osj65ngxk6w662itmcl6uvwtstfmp6b5w.py b/progress/SpecForge/cache/compiled_kernels/uu/cuuphv47i4uyclwboo4osj65ngxk6w662itmcl6uvwtstfmp6b5w.py new file mode 100644 index 0000000000000000000000000000000000000000..a89c572242afac84cba98314cec88077bc2ca91e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uu/cuuphv47i4uyclwboo4osj65ngxk6w662itmcl6uvwtstfmp6b5w.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/uw/cuwcv4h4vakxcuyvc6if7npzzbojjcxco6syj4bpqrvtvsdbfcus.py b/progress/SpecForge/cache/compiled_kernels/uw/cuwcv4h4vakxcuyvc6if7npzzbojjcxco6syj4bpqrvtvsdbfcus.py new file mode 100644 index 0000000000000000000000000000000000000000..3c35150e0f50712ad76cfe9a616b5c8c62bc80b8 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uw/cuwcv4h4vakxcuyvc6if7npzzbojjcxco6syj4bpqrvtvsdbfcus.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gb/cgb3qo2uhbvd22jllo3pcd4qrm2qklyyeuprletllykljldff6hp.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1712, 128][7012352, 128, 4096, 1]cuda:2" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1712, 128][1753088, 128, 1024, 1]cuda:2" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 14][14, 14, 1]cuda:2" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 14, 14][196, 196, 14, 1]cuda:2" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1712, 1712, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7012352, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1753088, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1753088, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1712 + ZKV = 1 + KV_LEN = 1712 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 14 + stride_kv_idx_h = 196 + stride_kv_idx_m = 14 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 219136*idx_hq + 7012352*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/bj/cbjrt5iaqfrwd545ebmz2cjez2n7o26733qo7khjrhewzsdlew7f.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1712][54784, 1712, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 438272}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 54784 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1712, 128), (7012352, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1712, 128), (1753088, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1712, 128), (1753088, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 14), (14, 14, 1)) + assert_size_stride(arg4_1, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(arg5_1, (1, 1, 14), (14, 14, 1)) + assert_size_stride(arg6_1, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(arg7_1, (1, 1, 14), (14, 14, 1)) + assert_size_stride(arg8_1, (1, 1, 14, 14), (196, 196, 14, 1)) + assert_size_stride(arg9_1, (1, 1, 14), (14, 14, 1)) + assert_size_stride(arg10_1, (1, 1, 14, 14), (196, 196, 14, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1712), (54784, 1712, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1712, 128), (7012352, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 14, 1, 32, stream=stream2) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, 54784, stream=stream2) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1712, 128), (7012352, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1712, 128), (1753088, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 14), (14, 14, 1), device='cuda:2', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 14, 14), (196, 196, 14, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/uz/cuzdztvtrcfpxhjbqtna5oap4csrrss47lhi6pbhbuclhrs6pht6.py b/progress/SpecForge/cache/compiled_kernels/uz/cuzdztvtrcfpxhjbqtna5oap4csrrss47lhi6pbhbuclhrs6pht6.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b4909626cb63b84c87bf07aa34f952bb1ff237 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/uz/cuzdztvtrcfpxhjbqtna5oap4csrrss47lhi6pbhbuclhrs6pht6.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/v7/cv7bq24vw3lbgqi552cucrf4xp5es3suk3lkvxoyixnnlwluqxpm.py b/progress/SpecForge/cache/compiled_kernels/v7/cv7bq24vw3lbgqi552cucrf4xp5es3suk3lkvxoyixnnlwluqxpm.py new file mode 100644 index 0000000000000000000000000000000000000000..de5d9a038b823a337d73e7f6fe9ddecf69c42194 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/v7/cv7bq24vw3lbgqi552cucrf4xp5es3suk3lkvxoyixnnlwluqxpm.py @@ -0,0 +1,694 @@ +# AOT ID: ['0_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ci/ccilbfhuqcuajlbda4kwxy44prh4kt4c2tfzmwu7dbyxf76voqar.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 1488, 128][6094848, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 1488, 128][1523712, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_4 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_5] +# %primals_6 : Tensor "i32[1, 1, 12][12, 12, 1]cuda:0" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 12, 12][144, 144, 12, 1]cuda:0" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (1488, 1488, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 6094848, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1523712, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1523712, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1488 + ZKV = 1 + KV_LEN = 1488 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 12 + stride_kv_idx_h = 144 + stride_kv_idx_m = 12 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 190464*idx_hq + 6094848*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zj/czjxl4o5zq2vmca3k2ha3vhclo3rwmhvxomcpzhwqmwuamwvvilm.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1488][47616, 1488, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 380928}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 47616 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 1488, 128), (6094848, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 1488, 128), (1523712, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_5, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_6, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_7, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_8, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_9, (1, 1, 12, 12), (144, 144, 12, 1)) + assert_size_stride(primals_10, (1, 1, 12), (12, 12, 1)) + assert_size_stride(primals_11, (1, 1, 12, 12), (144, 144, 12, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1488), (47616, 1488, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1488, 128), (6094848, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_4, primals_5, primals_6, primals_7, buf2, 12, 1, 32, stream=stream0) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, 47616, stream=stream0) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 1488, 128), (6094848, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 1488, 128), (1523712, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 12), (12, 12, 1), device='cuda:0', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 12, 12), (144, 144, 12, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/va/cvafcw6cv2o3tw6rshgkv7ewebtivbhjqnis53gw3o2prxtrjnha.py b/progress/SpecForge/cache/compiled_kernels/va/cvafcw6cv2o3tw6rshgkv7ewebtivbhjqnis53gw3o2prxtrjnha.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac63bb7689c5927b58548e351298abb4dd2963e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/va/cvafcw6cv2o3tw6rshgkv7ewebtivbhjqnis53gw3o2prxtrjnha.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 13 + stride_q_idx_h = 169 + stride_q_idx_n = 13 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/ve/cvekecsr2xyzw43nnxda3edf2n52ypuxglntdv5ezkc3qxj2h3lt.py b/progress/SpecForge/cache/compiled_kernels/ve/cvekecsr2xyzw43nnxda3edf2n52ypuxglntdv5ezkc3qxj2h3lt.py new file mode 100644 index 0000000000000000000000000000000000000000..ac90955e08a4a197ed08be4aef343ab0d30156b5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ve/cvekecsr2xyzw43nnxda3edf2n52ypuxglntdv5ezkc3qxj2h3lt.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 79872*idx_hq + 2555904*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/vj/cvj4vgdwmfi5ohh7mn6zi4ttb7xw7ozyhs3to7ugrnlgkitqtkp2.py b/progress/SpecForge/cache/compiled_kernels/vj/cvj4vgdwmfi5ohh7mn6zi4ttb7xw7ozyhs3to7ugrnlgkitqtkp2.py new file mode 100644 index 0000000000000000000000000000000000000000..7c85c881e3c69fa443a8c49cc13f55bf6bb79394 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/vj/cvj4vgdwmfi5ohh7mn6zi4ttb7xw7ozyhs3to7ugrnlgkitqtkp2.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/vr/cvrfb2q74qdycranqio6w5hedir6jf3qhnk3i6zf525om3i55llm.py b/progress/SpecForge/cache/compiled_kernels/vr/cvrfb2q74qdycranqio6w5hedir6jf3qhnk3i6zf525om3i55llm.py new file mode 100644 index 0000000000000000000000000000000000000000..8aab1585b82401f9166e68f1b74cd2ff5856f8bf --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/vr/cvrfb2q74qdycranqio6w5hedir6jf3qhnk3i6zf525om3i55llm.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/vw/cvw37c5o5w4a4pqoaipfqynicrxcxgsu4e5qszvgikzwey42hrqy.py b/progress/SpecForge/cache/compiled_kernels/vw/cvw37c5o5w4a4pqoaipfqynicrxcxgsu4e5qszvgikzwey42hrqy.py new file mode 100644 index 0000000000000000000000000000000000000000..b72e78525b435a56ae972e44228900f48a2a570c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/vw/cvw37c5o5w4a4pqoaipfqynicrxcxgsu4e5qszvgikzwey42hrqy.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/vx/cvx7b5luqi3solyr5bc6yhgrkk5aapzgjr2gnvhxlvckcb4gzxj2.py b/progress/SpecForge/cache/compiled_kernels/vx/cvx7b5luqi3solyr5bc6yhgrkk5aapzgjr2gnvhxlvckcb4gzxj2.py new file mode 100644 index 0000000000000000000000000000000000000000..4d68743c44d9833a9af59750ba3d249f69caa504 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/vx/cvx7b5luqi3solyr5bc6yhgrkk5aapzgjr2gnvhxlvckcb4gzxj2.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hq/chqmxyert4bgtobyyi5cbjgrutmamoo4uk7ilsavd4ghz5g3vd2i.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/e4/ce4r5alb7hdiabxo2umun2fjkjfew7uwhmxhgzudgs2hgl4d2ze5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:4" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:4" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:4" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:4" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:4" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:4" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 10 + stride_q_idx_h = 100 + stride_q_idx_n = 10 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_15, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(primals_16, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_17, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(primals_18, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_19, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream4) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1280 + primals_11 = 1280 + primals_7 = 10 + primals_8 = 10 + primals_12 = 10 + primals_2 = rand_strided((1, 32, 1280, 128), (5242880, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 1280, 128), (1310720, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 1280, 128), (1310720, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((1, 32, 1280, 128), (5242880, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1280), (40960, 1280, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1280, 128), (5242880, 163840, 128, 1), device='cuda:4', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1280), (40960, 1280, 1), device='cuda:4', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/w4/cw4fsmrzlz3yxap27ofc3ad2deoo2c6ok5s5m63cejtspevv4x7z.py b/progress/SpecForge/cache/compiled_kernels/w4/cw4fsmrzlz3yxap27ofc3ad2deoo2c6ok5s5m63cejtspevv4x7z.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9e373591ca7fd8b2808dfc2c4979f36a3451c1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/w4/cw4fsmrzlz3yxap27ofc3ad2deoo2c6ok5s5m63cejtspevv4x7z.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/w4/f0f39341a364b97de7ebf532befc9ac8dcfd4e6751fa4e8a496d3a6bb04d646f.best_config b/progress/SpecForge/cache/compiled_kernels/w4/f0f39341a364b97de7ebf532befc9ac8dcfd4e6751fa4e8a496d3a6bb04d646f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..40730bb99578020826990ee2fa0abea87202f3e0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/w4/f0f39341a364b97de7ebf532befc9ac8dcfd4e6751fa4e8a496d3a6bb04d646f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 47, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/w6/cw6gm2uslcvxn6iwc5f5uyiyynytygtmmnznph2mhzxeogokdj5b.py b/progress/SpecForge/cache/compiled_kernels/w6/cw6gm2uslcvxn6iwc5f5uyiyynytygtmmnznph2mhzxeogokdj5b.py new file mode 100644 index 0000000000000000000000000000000000000000..472afda4598e2e822a766b7124f0c7a98288b658 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/w6/cw6gm2uslcvxn6iwc5f5uyiyynytygtmmnznph2mhzxeogokdj5b.py @@ -0,0 +1,1018 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zl/czl7xvjffxknhj474g2xpnuucujqtdridm4gjtkxclxkcxeajpzr.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 624, 128][2555904, 79872, 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (624, 624, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 239616, 'r0_': 10223616}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 19968 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 624) + x1 = xindex // 624 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ry/crylsfwy4n57m7yd6unhb67v2ldprgpyuhxft3lhc6xp2xkvgns5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 624, 128][2555904, 79872, 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_8 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (624, 624, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 638976, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 2555904, 79872, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 2555904, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 5 + stride_q_idx_h = 25 + stride_q_idx_n = 5 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 79872*off_hkv + 638976*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 624 + KV_LEN = 624 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_5, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_6, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_7, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_8, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_9, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_10, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_11, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(getitem, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 624), (19968, 624, 1)) + assert_size_stride(tangents_1, (1, 32, 624, 128), (2555904, 79872, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 624), (19968, 624, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf1 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 19968, 128, stream=stream4) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 624, 128), (2555904, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 624, 128), (638976, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 624, 128), (638976, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 25, 1, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 624), (19968, 624, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 624, 128), (2555904, 79872, 128, 1), device='cuda:4', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 624), (19968, 624, 1), device='cuda:4', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/w7/cw7357bop5xftxhhquwtx3icyjabhhvdhitkzz7bxqds4e5bhmzj.py b/progress/SpecForge/cache/compiled_kernels/w7/cw7357bop5xftxhhquwtx3icyjabhhvdhitkzz7bxqds4e5bhmzj.py new file mode 100644 index 0000000000000000000000000000000000000000..0478a2612fe096552605db6ccb231228fb9bb2a6 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/w7/cw7357bop5xftxhhquwtx3icyjabhhvdhitkzz7bxqds4e5bhmzj.py @@ -0,0 +1,866 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/xv/cxv75l5tngryemr2gz2tewmwjhcxtrh7oskhpjurg3h2g73kiw5e.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 48, 128][196608, 128, 4096, 1]cuda:3" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 48, 128][49152, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 48, 128][49152, 128, 1024, 1]cuda:3" = PlaceHolder[target=arg2_1] +# %buf0 : Tensor "f32[1, 32, 32, 48][49152, 1536, 48, 1]cuda:3" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, 48][49152, 1536, 48, 1]cuda:3" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:3" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:3" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (48, 48, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 196608, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 49152, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 49152, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 49152, 1536, 48, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 49152, 1536, 48, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = 48 + KV_LEN = 48 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t + 6291456*idx_z + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/r5/cr5gimrtqf5gqmgkgirz6nebp4ijwtbzhmqlsdgjqda33dr43qre.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, 48][1536, 1536, 48, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, 48][1536, 48, 1]cuda:3" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (48, 48, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %mul : Tensor "f32[1, 32, 48][1536, 48, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%mul +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 36864, 'r0_': 0}} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 1536 + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 1536*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 1536*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x0), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ee/ceeui5pfegm2pcpkokzghtbznstx2p5iqb3n6wdv2ci43lzhlejq.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, 48, 128][6291456, 196608, 6144, 128, 1]cuda:3" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, 48][1536, 1536, 48, 1]cuda:3" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, 48, 128][196608, 6144, 128, 1]cuda:3" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, 48][1536, 48, 1]cuda:3" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (48, 48, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, 48, 128][196608, 128, 4096, 1]cuda:3"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 262144, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 25952256, 'r0_': 0}} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 196608 + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % 48) + x4 = xindex // 6144 + tmp0 = tl.load(in_ptr0 + (x5 + 196608*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 1536*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 48, 128), (196608, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 48, 128), (49152, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 48, 128), (49152, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg4_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg5_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg7_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg8_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, 32, 48), (49152, 1536, 48, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, 48), (49152, 1536, 48, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, 48, 128), (6291456, 196608, 6144, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 8, 32, 1, stream=stream3) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = empty_strided_cuda((1, 1, 32, 48), (1536, 1536, 48, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, 48), (1536, 48, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, 48), (1536, 48, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + stream3 = get_raw_stream(3) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, 1536, 32, stream=stream3) + del buf1 + buf9 = empty_strided_cuda((1, 32, 48, 128), (196608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, 196608, 32, stream=stream3) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:3', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/w7/cw7v36vovuusyu5wa4jonx4xwdt2temsvyczysvmx2invlw2vgjl.py b/progress/SpecForge/cache/compiled_kernels/w7/cw7v36vovuusyu5wa4jonx4xwdt2temsvyczysvmx2invlw2vgjl.py new file mode 100644 index 0000000000000000000000000000000000000000..956cd8148cc05238705e68848405212407e4d312 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/w7/cw7v36vovuusyu5wa4jonx4xwdt2temsvyczysvmx2invlw2vgjl.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/wa/cwa3d7u6v26iskg2c3qsu77ukwhwjvz4nhvymqcjq5rueezli4mf.py b/progress/SpecForge/cache/compiled_kernels/wa/cwa3d7u6v26iskg2c3qsu77ukwhwjvz4nhvymqcjq5rueezli4mf.py new file mode 100644 index 0000000000000000000000000000000000000000..0664411b13983aeb31b199d93f1b529b53605dc5 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wa/cwa3d7u6v26iskg2c3qsu77ukwhwjvz4nhvymqcjq5rueezli4mf.py @@ -0,0 +1,731 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/fs/cfs2of6ygpacl7vcs74vy37xwnw4avrqcbqon6ht56dira6khkkl.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:2" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_18] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/zi/czisxx2rasv3jxm6nvp2yhkwlbob3ysmyfupuctvxqfwp52ekm3f.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s84 = primals_21 + s53 = primals_22 + s100 = primals_24 + s5 = primals_26 + s10 = primals_27 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_15, primals_18, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream2) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream2) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, buf2, buf0, s37, s0, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1600 + primals_2 = rand_strided((1, 32, 1600, 128), (6553600, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = 1600 + primals_4 = rand_strided((1, 8, 1600, 128), (1638400, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_5 = 1600 + primals_6 = rand_strided((1, 8, 1600, 128), (1638400, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = 13 + primals_8 = 13 + primals_9 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_10 = 1600 + primals_11 = 1600 + primals_12 = 13 + primals_13 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_14 = 13 + primals_15 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_16 = 13 + primals_17 = 13 + primals_18 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_19 = 13 + primals_20 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_21 = 13 + primals_22 = 13 + primals_23 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + primals_24 = 13 + primals_25 = rand_strided((1, 1, 13), (13, 13, 1), device='cuda:2', dtype=torch.int32) + primals_26 = 13 + primals_27 = 13 + primals_28 = rand_strided((1, 1, 13, 13), (169, 169, 13, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/wa/cwagxen2jq5ds6xroa7i3jzlbijcannmnasinwrxzd7w5hx7yfaj.py b/progress/SpecForge/cache/compiled_kernels/wa/cwagxen2jq5ds6xroa7i3jzlbijcannmnasinwrxzd7w5hx7yfaj.py new file mode 100644 index 0000000000000000000000000000000000000000..e6fdb17cc7de61b299606b87de0bfbccbed14863 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wa/cwagxen2jq5ds6xroa7i3jzlbijcannmnasinwrxzd7w5hx7yfaj.py @@ -0,0 +1,707 @@ +# AOT ID: ['5_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/jw/cjwtqvxssq5bhbgq4dcc3a5swgw24a7smwlcuuhvmz7ctbt4rdza.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/5q/c5qxm76eg6wfczhqmdc3gfksgjzrbgtma7q4i4gsyz7dq3yw7ikj.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, (127 + s37) // 128, 1, 32, stream=stream6) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream6) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf2, buf0, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 128 + primals_2 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = 128 + primals_4 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = 128 + primals_6 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_8 = 128 + primals_9 = 128 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/wc/587d36217871077a029f489140670c9a88c47fad84b833effc4bb72273ef3a76.best_config b/progress/SpecForge/cache/compiled_kernels/wc/587d36217871077a029f489140670c9a88c47fad84b833effc4bb72273ef3a76.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d3d4f60b2fd8ff56129ebb8dae443a6b5220301b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wc/587d36217871077a029f489140670c9a88c47fad84b833effc4bb72273ef3a76.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 14, "triton_cache_hash": "GWPHKQM6FYHNRATK3VNSWNEFGZENRP4ZFLSI7WCAA73775CHNCPQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/wc/cwcq2mf7cnltkrq7krmalipwcxj2b46f7bdyonvnqsqw5uhfomts.py b/progress/SpecForge/cache/compiled_kernels/wc/cwcq2mf7cnltkrq7krmalipwcxj2b46f7bdyonvnqsqw5uhfomts.py new file mode 100644 index 0000000000000000000000000000000000000000..5711f984c0859de8bbb13cbe543257e0504e30d3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wc/cwcq2mf7cnltkrq7krmalipwcxj2b46f7bdyonvnqsqw5uhfomts.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 167936}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 20992 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/wd/cwd5z67eoub753kyoyz427o4lmivpjimmoxxflrmq35ty4kui5ts.py b/progress/SpecForge/cache/compiled_kernels/wd/cwd5z67eoub753kyoyz427o4lmivpjimmoxxflrmq35ty4kui5ts.py new file mode 100644 index 0000000000000000000000000000000000000000..33cf5d931319364fe9b6a1b9dda8bdab0639df1b --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wd/cwd5z67eoub753kyoyz427o4lmivpjimmoxxflrmq35ty4kui5ts.py @@ -0,0 +1,1018 @@ +# AOT ID: ['0_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, 976, 128][3997696, 124928, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=tangents_2] +# %mul_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (976, 976, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 374784, 'r0_': 15990784}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 31232 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 976) + x1 = xindex // 976 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/om/comxwt7qmz7r76sshdlb7eiymhymj24rvyjhgcvd4d54rii5xcxm.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, 976, 128][3997696, 124928, 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, 976, 128][3997696, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, 976, 128][999424, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_4 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_5 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_5] +# %primals_8 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_8] +# %primals_9 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_6 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_10 : Tensor "i32[1, 1, 8][8, 8, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_11 : Tensor "i32[1, 1, 8, 8][64, 64, 8, 1]cuda:7" = PlaceHolder[target=primals_11] +# %mul_1 : Tensor "f32[1, 32, 976][31232, 976, 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %mul_1, %fw_graph0, %joint_graph0, (976, 976, %primals_4, %primals_5, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 3997696, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 999424, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 999424, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 3997696, 124928, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 3997696, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 999424, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = 976 + ZKV = 1 + KV_LEN = 976 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 8 + stride_kv_idx_h = 64 + stride_kv_idx_m = 8 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 8 + stride_q_idx_h = 64 + stride_q_idx_n = 8 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 124928*off_hkv + 999424*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 976 + KV_LEN = 976 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 976, 128), (999424, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_5, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_6, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_7, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_8, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_9, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(primals_10, (1, 1, 8), (8, 8, 1)) + assert_size_stride(primals_11, (1, 1, 8, 8), (64, 64, 8, 1)) + assert_size_stride(getitem, (1, 32, 976, 128), (3997696, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, 976), (31232, 976, 1)) + assert_size_stride(tangents_1, (1, 32, 976, 128), (3997696, 124928, 128, 1)) + assert_size_stride(tangents_2, (1, 32, 976), (31232, 976, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((1, 32, 976), (31232, 976, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, 31232, 128, stream=stream7) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, 976, 128), (3997696, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, 976, 128), (999424, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, 976, 128), (999424, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_tem_fused_mul_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_4, primals_5, primals_8, primals_9, primals_6, primals_7, primals_10, primals_11, buf5, 40, 1, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 976, 128), (999424, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 8), (8, 8, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 8, 8), (64, 64, 8, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((1, 32, 976, 128), (3997696, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 976), (31232, 976, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 976, 128), (3997696, 124928, 128, 1), device='cuda:7', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 976), (31232, 976, 1), device='cuda:7', dtype=torch.float32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/wd/cwd6s5yxcb5l6onixb7a5phqy62eimthnyhapdjvvmvivprxr2t5.py b/progress/SpecForge/cache/compiled_kernels/wd/cwd6s5yxcb5l6onixb7a5phqy62eimthnyhapdjvvmvivprxr2t5.py new file mode 100644 index 0000000000000000000000000000000000000000..318eaa05cca8a6656a0fe2fdb6281dacaa679132 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wd/cwd6s5yxcb5l6onixb7a5phqy62eimthnyhapdjvvmvivprxr2t5.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/wg/cwgwwzr7odt2374bwa4hc5trgmiyoenh7xskuj3j673yf3oyggqg.py b/progress/SpecForge/cache/compiled_kernels/wg/cwgwwzr7odt2374bwa4hc5trgmiyoenh7xskuj3j673yf3oyggqg.py new file mode 100644 index 0000000000000000000000000000000000000000..76777c801e1c97cb5fedc43696992d0c5e7e89eb --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wg/cwgwwzr7odt2374bwa4hc5trgmiyoenh7xskuj3j673yf3oyggqg.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 11 + stride_q_idx_h = 121 + stride_q_idx_n = 11 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/wi/cwiezzipnqllejaj4y42z2qr2s4hzhrpbkyyp4nv3hkr57gw3upw.py b/progress/SpecForge/cache/compiled_kernels/wi/cwiezzipnqllejaj4y42z2qr2s4hzhrpbkyyp4nv3hkr57gw3upw.py new file mode 100644 index 0000000000000000000000000000000000000000..772496861fbf3e94231a12f8537a14d84f4d9e4e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wi/cwiezzipnqllejaj4y42z2qr2s4hzhrpbkyyp4nv3hkr57gw3upw.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/wi/cwiof4iou6xjj5lsqeya5ndeadiieeztoux6qjj6lgcomom2mhjv.py b/progress/SpecForge/cache/compiled_kernels/wi/cwiof4iou6xjj5lsqeya5ndeadiieeztoux6qjj6lgcomom2mhjv.py new file mode 100644 index 0000000000000000000000000000000000000000..d4367b55eb7e9801710a7b4d85aee7b5c7a53ecd --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wi/cwiof4iou6xjj5lsqeya5ndeadiieeztoux6qjj6lgcomom2mhjv.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks3 + stride_kv_idx_h = ks4*ks5 + stride_kv_idx_m = ks5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/wq/cwqqg2cesfh4pqmcqtr7fhe7os3hpaefbyse4lnx7tltftkifqap.py b/progress/SpecForge/cache/compiled_kernels/wq/cwqqg2cesfh4pqmcqtr7fhe7os3hpaefbyse4lnx7tltftkifqap.py new file mode 100644 index 0000000000000000000000000000000000000000..937018c13156ab86b8f3d262d00ef99422d6df1d --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/wq/cwqqg2cesfh4pqmcqtr7fhe7os3hpaefbyse4lnx7tltftkifqap.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sj/csjgoierbyb7y37xze3raek6tsz2nam2dgfupz3aqsd5fa3pyzto.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:3" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:3" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:3" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:3" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 16][16, 16, 1]cuda:3" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 16, 16][256, 256, 16, 1]cuda:3" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yb/cybuifkbme5pqec7hqplzgosjwtmpxbo3ubhvaobchgrztupyhkd.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 16), (16, 16, 1)) + assert_size_stride(primals_15, (1, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_16, (1, 1, 16), (16, 16, 1)) + assert_size_stride(primals_17, (1, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_18, (1, 1, 16), (16, 16, 1)) + assert_size_stride(primals_19, (1, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream3 = get_raw_stream(3) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream3) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream3 = get_raw_stream(3) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream3) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = rand_strided((1, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_3 = 2048 + primals_4 = rand_strided((1, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = 16 + primals_8 = 16 + primals_9 = rand_strided((1, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_10 = 2048 + primals_11 = 2048 + primals_12 = 16 + primals_13 = rand_strided((1, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 16), (16, 16, 1), device='cuda:3', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 16, 16), (256, 256, 16, 1), device='cuda:3', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/x4/460272e788f974ed79e703e806eff2283d1de9b894d808de98fc3af38dc9a2d1.best_config b/progress/SpecForge/cache/compiled_kernels/x4/460272e788f974ed79e703e806eff2283d1de9b894d808de98fc3af38dc9a2d1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..d131a13c7786a0afd9ad24a5628260bf7969ba42 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/x4/460272e788f974ed79e703e806eff2283d1de9b894d808de98fc3af38dc9a2d1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 60, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/x4/cx4s7bbqf3jhjjzv3khgpg2sj5bxo6jayddibxat353olmtvrkps.py b/progress/SpecForge/cache/compiled_kernels/x4/cx4s7bbqf3jhjjzv3khgpg2sj5bxo6jayddibxat353olmtvrkps.py new file mode 100644 index 0000000000000000000000000000000000000000..01f189b73222772ff813f08080f27c3075f8d228 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/x4/cx4s7bbqf3jhjjzv3khgpg2sj5bxo6jayddibxat353olmtvrkps.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/x7/cx7zpedaas52tqlacuziqx2yyb6ksu5fcs3viesbzxnjtbw5uner.py b/progress/SpecForge/cache/compiled_kernels/x7/cx7zpedaas52tqlacuziqx2yyb6ksu5fcs3viesbzxnjtbw5uner.py new file mode 100644 index 0000000000000000000000000000000000000000..94fa9c2cfbe84ab77927c25a1d6d57dd5b9affe0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/x7/cx7zpedaas52tqlacuziqx2yyb6ksu5fcs3viesbzxnjtbw5uner.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/xf/872945ca0e2a2a09147289fc9f490f5985db5b5a8cdf756257cf52a04c5639cf.best_config b/progress/SpecForge/cache/compiled_kernels/xf/872945ca0e2a2a09147289fc9f490f5985db5b5a8cdf756257cf52a04c5639cf.best_config new file mode 100644 index 0000000000000000000000000000000000000000..6d223c32ca881d0bdba67ab08030e3c79433debc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xf/872945ca0e2a2a09147289fc9f490f5985db5b5a8cdf756257cf52a04c5639cf.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 71, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/xf/cxfq6ic3efg6yukl55sige4jtfyzy5dlspo6ipsr5mew7q7i32dg.py b/progress/SpecForge/cache/compiled_kernels/xf/cxfq6ic3efg6yukl55sige4jtfyzy5dlspo6ipsr5mew7q7i32dg.py new file mode 100644 index 0000000000000000000000000000000000000000..fb392db324e36181ec8f5acbf29821842b7ce255 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xf/cxfq6ic3efg6yukl55sige4jtfyzy5dlspo6ipsr5mew7q7i32dg.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/xg/2ed1a22f6f43a35785119e255e86bda14658c161a492d8e2f79d1894cd624dfa.best_config b/progress/SpecForge/cache/compiled_kernels/xg/2ed1a22f6f43a35785119e255e86bda14658c161a492d8e2f79d1894cd624dfa.best_config new file mode 100644 index 0000000000000000000000000000000000000000..8d72fa94b62edd82d44331a4036ade0e10e18f2f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xg/2ed1a22f6f43a35785119e255e86bda14658c161a492d8e2f79d1894cd624dfa.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 8, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 108, "triton_cache_hash": "XRPIXE6422Z3WVFKM6FTH3VU3RBLBAM5QFGQDRDJKHCOAJAWTZHQ"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/xg/cxg4rooeabvsbvmyiis4wo5wrcdia6pxyhr5tm3rece3ty6f3zys.py b/progress/SpecForge/cache/compiled_kernels/xg/cxg4rooeabvsbvmyiis4wo5wrcdia6pxyhr5tm3rece3ty6f3zys.py new file mode 100644 index 0000000000000000000000000000000000000000..dca8586a7c669521ab3cdefefae67f44639fd757 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xg/cxg4rooeabvsbvmyiis4wo5wrcdia6pxyhr5tm3rece3ty6f3zys.py @@ -0,0 +1,54 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) diff --git a/progress/SpecForge/cache/compiled_kernels/xi/cxixit3txgzdbhf5t7iy5blh5uin52vrlk3ivqurfua234pc4kl3.py b/progress/SpecForge/cache/compiled_kernels/xi/cxixit3txgzdbhf5t7iy5blh5uin52vrlk3ivqurfua234pc4kl3.py new file mode 100644 index 0000000000000000000000000000000000000000..84b2ce0a2b281c0623c3cd8d952df4dca60bf470 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xi/cxixit3txgzdbhf5t7iy5blh5uin52vrlk3ivqurfua234pc4kl3.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/xk/3c2a79223a2cd1e15f76e12cce20cb10f03a6234e706aac9416d0a8741a101a2.best_config b/progress/SpecForge/cache/compiled_kernels/xk/3c2a79223a2cd1e15f76e12cce20cb10f03a6234e706aac9416d0a8741a101a2.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xk/3c2a79223a2cd1e15f76e12cce20cb10f03a6234e706aac9416d0a8741a101a2.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/xk/cxkaggjlxk3uyhnwwv4cbpmjoszxdy4lqcdib4aaljgij4vgm32o.py b/progress/SpecForge/cache/compiled_kernels/xk/cxkaggjlxk3uyhnwwv4cbpmjoszxdy4lqcdib4aaljgij4vgm32o.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad145f151e3773dd0d4edc2cd59114ee37d08e4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xk/cxkaggjlxk3uyhnwwv4cbpmjoszxdy4lqcdib4aaljgij4vgm32o.py @@ -0,0 +1,1028 @@ +# AOT ID: ['1_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/js/cjshj4sn3r4i5oztbzd5rsp5cgfgoeb5sxl7zfjsj7pboonvwjbu.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:7" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:7" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:7" = PlaceHolder[target=primals_9] +# %primals_16 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:7" = PlaceHolder[target=primals_16] +# %primals_17 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:7" = PlaceHolder[target=primals_17] +# %primals_14 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:7" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:7" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, 11][11, 11, 1]cuda:7" = PlaceHolder[target=primals_18] +# %primals_19 : Tensor "i32[1, 1, 11, 11][121, 121, 11, 1]cuda:7" = PlaceHolder[target=primals_19] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 11 + stride_q_idx_h = 121 + stride_q_idx_n = 11 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_15, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_16, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_17, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(primals_18, (1, 1, 11), (11, 11, 1)) + assert_size_stride(primals_19, (1, 1, 11, 11), (121, 121, 11, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream7) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream7 = get_raw_stream(7) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_16, primals_17, primals_14, primals_15, primals_18, primals_19, buf5, s37, s0, s99, s22, s72, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream7) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_17 + del primals_18 + del primals_19 + del primals_2 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1328 + primals_11 = 1328 + primals_7 = 11 + primals_8 = 11 + primals_12 = 11 + primals_2 = rand_strided((1, 32, 1328, 128), (5439488, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 1328, 128), (1359872, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 1328, 128), (1359872, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:7', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:7', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:7', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:7', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 11), (11, 11, 1), device='cuda:7', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 11, 11), (121, 121, 11, 1), device='cuda:7', dtype=torch.int32) + getitem = rand_strided((1, 32, 1328, 128), (5439488, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 1328), (42496, 1328, 1), device='cuda:7', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 1328, 128), (5439488, 169984, 128, 1), device='cuda:7', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 1328), (42496, 1328, 1), device='cuda:7', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/xk/cxktzujto4ftosl5gghkudqc4rbz3ix4zcxmbpo5ijvvteuoihei.py b/progress/SpecForge/cache/compiled_kernels/xk/cxktzujto4ftosl5gghkudqc4rbz3ix4zcxmbpo5ijvvteuoihei.py new file mode 100644 index 0000000000000000000000000000000000000000..8903c44567c122df2382fc15e49751a3a07da45f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xk/cxktzujto4ftosl5gghkudqc4rbz3ix4zcxmbpo5ijvvteuoihei.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/xq/cxqc7mlw6256g3dcu6mnc6zedujssjsu3g736logtr7zcouuboiv.py b/progress/SpecForge/cache/compiled_kernels/xq/cxqc7mlw6256g3dcu6mnc6zedujssjsu3g736logtr7zcouuboiv.py new file mode 100644 index 0000000000000000000000000000000000000000..9228295c4e4a5ddc433d4bd30fc0783866f644e7 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xq/cxqc7mlw6256g3dcu6mnc6zedujssjsu3g736logtr7zcouuboiv.py @@ -0,0 +1,694 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=primals_6] +# %primals_7 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=primals_7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (624, 624, %primals_5, %primals_4, %primals_6, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 79872*idx_hq + 2555904*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/4a/c4anvofd7zpauvjftktjqqrtqgoz7sd5ckg3wf7op3c7cmnbrzz6.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 159744}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 19968 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11 = args + args.clear() + assert_size_stride(primals_1, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(primals_2, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_3, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(primals_4, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_5, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_6, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_7, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_8, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_9, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(primals_10, (1, 1, 5), (5, 5, 1)) + assert_size_stride(primals_11, (1, 1, 5, 5), (25, 25, 5, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 624, 128), (2555904, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_6, primals_7, buf2, 5, 1, 32, stream=stream4) + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, 19968, stream=stream4) + return (buf2, buf5, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_7 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/xv/cxv75l5tngryemr2gz2tewmwjhcxtrh7oskhpjurg3h2g73kiw5e.py b/progress/SpecForge/cache/compiled_kernels/xv/cxv75l5tngryemr2gz2tewmwjhcxtrh7oskhpjurg3h2g73kiw5e.py new file mode 100644 index 0000000000000000000000000000000000000000..288e6178ba22ec5d58cbe15ef22c4622ab1ede04 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xv/cxv75l5tngryemr2gz2tewmwjhcxtrh7oskhpjurg3h2g73kiw5e.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 196608, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 49152, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 49152, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 49152, 1536, 48, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 49152, 1536, 48, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = 48 + KV_LEN = 48 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t + 6291456*idx_z + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 6144*idx_hq + 196608*idx_t, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i diff --git a/progress/SpecForge/cache/compiled_kernels/xz/cxzfcp7h25xdy6mvnsf7rcwqyukbgu2thkfn7qmwcn2fujxft5x5.py b/progress/SpecForge/cache/compiled_kernels/xz/cxzfcp7h25xdy6mvnsf7rcwqyukbgu2thkfn7qmwcn2fujxft5x5.py new file mode 100644 index 0000000000000000000000000000000000000000..584cf0aa244f54bfed645f2d0535419e0d290471 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/xz/cxzfcp7h25xdy6mvnsf7rcwqyukbgu2thkfn7qmwcn2fujxft5x5.py @@ -0,0 +1,1019 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/iy/ciyksx73miv4q3buobp6lpr65funn6ikpjyesnfdvjt7n2v4tfmd.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_per_fused_mul_0 = async_compile.triton('triton_per_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 128}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 128 + R0_BLOCK: tl.constexpr = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), xmask, other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, other=0.0).to(tl.float32) + tmp8 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = tl.where(xmask, tmp3, 0) + tmp6 = tl.sum(tmp5, 1)[:, None].to(tl.float32) + tmp7 = tmp6.to(tl.float32) + tmp9 = 0.6931471805599453 + tmp10 = tmp8 * tmp9 + tmp11 = 1.4426950408889634 + tmp12 = tmp10 * tmp11 + tmp13 = tmp7 - tmp12 + tl.store(out_ptr1 + (x3), tmp13, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/3u/c3umapah7vcozhvfk5uovlssor7v533y4crphqgd677nuoizbpvj.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_13 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_14 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_14] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_12] +# %primals_15 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:2" = PlaceHolder[target=primals_15] +# %primals_16 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:2" = PlaceHolder[target=primals_16] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_8 + s0 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_per_fused_mul_0_xnumel = 32*s37 + stream2 = get_raw_stream(2) + triton_per_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_per_fused_mul_0_xnumel, 128, stream=stream2) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream2 = get_raw_stream(2) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_10, primals_7, primals_13, primals_14, primals_11, primals_12, primals_15, primals_16, buf5, s37, s0, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_10 + del primals_11 + del primals_12 + del primals_13 + del primals_14 + del primals_15 + del primals_16 + del primals_2 + del primals_4 + del primals_6 + del primals_7 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 96 + primals_9 = 96 + primals_2 = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 96, 128), (98304, 128, 1024, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:2', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((1, 32, 96, 128), (393216, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 96), (3072, 96, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 96, 128), (393216, 12288, 128, 1), device='cuda:2', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 96), (3072, 96, 1), device='cuda:2', dtype=torch.float32) + fn = lambda: call([primals_8, primals_9, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/y2/cy2vkrqd6v6gmwfkzbarwszd4anv5zbramut3pnxjch62thk6us7.py b/progress/SpecForge/cache/compiled_kernels/y2/cy2vkrqd6v6gmwfkzbarwszd4anv5zbramut3pnxjch62thk6us7.py new file mode 100644 index 0000000000000000000000000000000000000000..10ef84c8fba46606d3333de2e258a0c9034fe778 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/y2/cy2vkrqd6v6gmwfkzbarwszd4anv5zbramut3pnxjch62thk6us7.py @@ -0,0 +1,876 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tl/ctlpu6dtplhnxq223amgvjrfwcqysnrjubory4o2ult7mshpvgyw.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:7" = PlaceHolder[target=primals_6] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:7" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:7" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/nl/cnlicf3gxmwqxgcrephkgh25nolfnuz7cungpmokpzy6oauhzf2v.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem_1 +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7" = PlaceHolder[target=getitem_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=2] = call_function[target=operator.getitem](args = (%flex_attention, 1), kwargs = {}) +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:7"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%getitem_1,%mul_15 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/pj/cpjqvrivm5zoy63twn7vyb2bstmqe7t47pwzubekts7kz4aagkyy.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:7" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:7" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:7"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 524288, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(7): + torch.cuda.set_device(7) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream7 = get_raw_stream(7) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, 8, 32, 1, stream=stream7) + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + buf11 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream7 = get_raw_stream(7) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, buf11, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream7) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream7 = get_raw_stream(7) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream7) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf11, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf9, buf10, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 80 + primals_2 = rand_strided((1, 32, 80, 128), (327680, 128, 4096, 1), device='cuda:7', dtype=torch.bfloat16) + primals_3 = 80 + primals_4 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_5 = 80 + primals_6 = rand_strided((1, 8, 80, 128), (81920, 128, 1024, 1), device='cuda:7', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_8 = 80 + primals_9 = 80 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:7', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:7', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/y3/cy3tcjz7bumb36zu5ij2k4rw5jslf3pn2hm4m4vdngje5z4s24an.py b/progress/SpecForge/cache/compiled_kernels/y3/cy3tcjz7bumb36zu5ij2k4rw5jslf3pn2hm4m4vdngje5z4s24an.py new file mode 100644 index 0000000000000000000000000000000000000000..8824f2d9ee19bf72429dac3c1d77c62f1d88d570 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/y3/cy3tcjz7bumb36zu5ij2k4rw5jslf3pn2hm4m4vdngje5z4s24an.py @@ -0,0 +1,715 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/da/cdasofnd3pbbt2ztadrt7liw6r5jwtyuf5j5jlt43d4d6wrqqs5w.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg1_1 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %arg3_1 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %arg5_1 : Tensor "bf16[1, 8, s43, 128][1024*s43, 128, 1024, 1]cuda:0" = PlaceHolder[target=arg5_1] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %arg9_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=arg9_1] +# %arg6_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=arg6_1] +# %arg10_1 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:0" = PlaceHolder[target=arg10_1] +# %arg11_1 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:0" = PlaceHolder[target=arg11_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg1_1, %arg3_1, %arg5_1, %sdpa_score0, (%arg7_1, %arg8_1, %arg9_1, %arg6_1, %arg10_1, %arg11_1, %arg12_1, %arg13_1, %arg14_1, %arg15_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/fx/cfxwr3zeufs4zyj7lcrehue35sjnqfjexlubedkegfvtfkzuyx4y.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_9 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_9 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_9 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1 = args + args.clear() + s50 = arg0_1 + s0 = arg2_1 + s43 = arg4_1 + s37 = arg7_1 + s71 = arg8_1 + assert_size_stride(arg1_1, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(arg3_1, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(arg5_1, (1, 8, s43, 128), (1024*s43, 128, 1024, 1)) + assert_size_stride(arg6_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg9_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg10_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg11_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg12_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg13_1, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(arg14_1, (1, 1, 1), (1, 1, 1)) + assert_size_stride(arg15_1, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(arg1_1, arg3_1, arg5_1, buf0, buf1, arg9_1, arg6_1, arg10_1, arg11_1, buf2, s37, s0, s43, (127 + s37) // 128, 1, 32, stream=stream0) + del arg10_1 + del arg11_1 + del arg1_1 + del arg3_1 + del arg5_1 + del arg6_1 + del arg9_1 + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 128 + arg1_1 = rand_strided((1, 32, 128, 128), (524288, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = 128 + arg3_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg4_1 = 128 + arg5_1 = rand_strided((1, 8, 128, 128), (131072, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + arg6_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg7_1 = 128 + arg8_1 = 128 + arg9_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg11_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg12_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg13_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + arg14_1 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:0', dtype=torch.int32) + arg15_1 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/y3/cy3vljzbimqjjou6kr37hbpukya23hax4dfmxpaz57r25iaadf5c.py b/progress/SpecForge/cache/compiled_kernels/y3/cy3vljzbimqjjou6kr37hbpukya23hax4dfmxpaz57r25iaadf5c.py new file mode 100644 index 0000000000000000000000000000000000000000..0465b317a196086f69fc931190750afa6ead7c69 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/y3/cy3vljzbimqjjou6kr37hbpukya23hax4dfmxpaz57r25iaadf5c.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/y5/839840f95a527c91b1ea2ac44b4e4cd757345edeb6a3b43da306d00cbb403c83.best_config b/progress/SpecForge/cache/compiled_kernels/y5/839840f95a527c91b1ea2ac44b4e4cd757345edeb6a3b43da306d00cbb403c83.best_config new file mode 100644 index 0000000000000000000000000000000000000000..95c33e2998cf8060ea61b42b0c24eeca276991b2 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/y5/839840f95a527c91b1ea2ac44b4e4cd757345edeb6a3b43da306d00cbb403c83.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "EZ3RKHM23IARJ6OLUDAWBKS54ORODGZICEVX6ZI5AEYV54IQCCLA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/y5/cy5jljxgujndttzwbv4ilh6rxm6speqtj427ms2af3tgcjqi4rts.py b/progress/SpecForge/cache/compiled_kernels/y5/cy5jljxgujndttzwbv4ilh6rxm6speqtj427ms2af3tgcjqi4rts.py new file mode 100644 index 0000000000000000000000000000000000000000..61d0cc053ea99b57cbcd417405798873bf28680e --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/y5/cy5jljxgujndttzwbv4ilh6rxm6speqtj427ms2af3tgcjqi4rts.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 483328}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 60416 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/ya/cyaaqpkxycl2udxe7g4rziimyij5a4u7u2ix5uwl6bovk3ugnbze.py b/progress/SpecForge/cache/compiled_kernels/ya/cyaaqpkxycl2udxe7g4rziimyij5a4u7u2ix5uwl6bovk3ugnbze.py new file mode 100644 index 0000000000000000000000000000000000000000..884650a4d98341a7081c2595a3dd4151fdebffbc --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ya/cyaaqpkxycl2udxe7g4rziimyij5a4u7u2ix5uwl6bovk3ugnbze.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention_backward(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 9 + stride_q_idx_h = 81 + stride_q_idx_n = 9 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ya/cyajcmmu43wsdskmmi3lnbow6xhgke4ozgqvwrn72tl3b2fss624.py b/progress/SpecForge/cache/compiled_kernels/ya/cyajcmmu43wsdskmmi3lnbow6xhgke4ozgqvwrn72tl3b2fss624.py new file mode 100644 index 0000000000000000000000000000000000000000..60e1c0203ada9e4b2cf4d4ce625581b0bc97f2f3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ya/cyajcmmu43wsdskmmi3lnbow6xhgke4ozgqvwrn72tl3b2fss624.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7j/c7jy7orghqlzyryav6xs6tfj7jilalobanvnpyxyky7yueuhrwlx.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 6][6, 6, 1]cuda:0" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 6, 6][36, 36, 6, 1]cuda:0" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/22/c225p5q54jhc2rfoccuzlgejscvq2in5jzxlzcilu44cplhbfreo.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_15, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_16, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_17, (1, 1, 6, 6), (36, 36, 6, 1)) + assert_size_stride(primals_18, (1, 1, 6), (6, 6, 1)) + assert_size_stride(primals_19, (1, 1, 6, 6), (36, 36, 6, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream0) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream0) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 720 + primals_2 = rand_strided((1, 32, 720, 128), (2949120, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = 720 + primals_4 = rand_strided((1, 8, 720, 128), (737280, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 720 + primals_6 = rand_strided((1, 8, 720, 128), (737280, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = 6 + primals_8 = 6 + primals_9 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_10 = 720 + primals_11 = 720 + primals_12 = 6 + primals_13 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 6), (6, 6, 1), device='cuda:0', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 6, 6), (36, 36, 6, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/yb/439a3ffb81622750326114f99c79de9c757682fcde216b521af8c1a20635ade1.best_config b/progress/SpecForge/cache/compiled_kernels/yb/439a3ffb81622750326114f99c79de9c757682fcde216b521af8c1a20635ade1.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3514cd4b7efbcfec5a9027a69143f1b0c3ed176a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yb/439a3ffb81622750326114f99c79de9c757682fcde216b521af8c1a20635ade1.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yb/cybd7xjn6u3nwakcjl5vtjx3jiyvymn4ksil4sdb23q4woszkxyb.py b/progress/SpecForge/cache/compiled_kernels/yb/cybd7xjn6u3nwakcjl5vtjx3jiyvymn4ksil4sdb23q4woszkxyb.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cdec619512e0adf82d047c7372a538abe07da3 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yb/cybd7xjn6u3nwakcjl5vtjx3jiyvymn4ksil4sdb23q4woszkxyb.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/am/camqma4hqvmxziggaorqeiiobjciqtakh45i2fu3p47njnbhow24.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 1824, 128][7471104, 128, 4096, 1]cuda:4" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 1824, 128][1867776, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 1824, 128][1867776, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 1824][58368, 1824, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 1824][58368, 1824, 1]cuda:4" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:4" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 15][15, 15, 1]cuda:4" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 15, 15][225, 225, 15, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (1824, 1824, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 7471104, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1867776, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1867776, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 1824 + ZKV = 1 + KV_LEN = 1824 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 15 + stride_kv_idx_h = 225 + stride_kv_idx_m = 15 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 233472*idx_hq + 7471104*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/yc/cych2j34c5zq6tvwqfjomrwqs66nalytep4dufch6knmdh36bdb3.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 1824][58368, 1824, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 466944}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 58368 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 1824, 128), (7471104, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 1824, 128), (1867776, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 1824, 128), (1867776, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 15), (15, 15, 1)) + assert_size_stride(arg4_1, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(arg5_1, (1, 1, 15), (15, 15, 1)) + assert_size_stride(arg6_1, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(arg7_1, (1, 1, 15), (15, 15, 1)) + assert_size_stride(arg8_1, (1, 1, 15, 15), (225, 225, 15, 1)) + assert_size_stride(arg9_1, (1, 1, 15), (15, 15, 1)) + assert_size_stride(arg10_1, (1, 1, 15, 15), (225, 225, 15, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, 1824), (58368, 1824, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 1824), (58368, 1824, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 1824, 128), (7471104, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 15, 1, 32, stream=stream4) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, 58368, stream=stream4) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 1824, 128), (7471104, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 1824, 128), (1867776, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 1824, 128), (1867776, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/yb/cybixkf4xfk36h3xorbdpftxrxeuynavnn73zkwli5ahlyodmtkq.py b/progress/SpecForge/cache/compiled_kernels/yb/cybixkf4xfk36h3xorbdpftxrxeuynavnn73zkwli5ahlyodmtkq.py new file mode 100644 index 0000000000000000000000000000000000000000..92fb1364747150f24f16efbe7dd2ec32f2d34059 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yb/cybixkf4xfk36h3xorbdpftxrxeuynavnn73zkwli5ahlyodmtkq.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 9 + stride_q_idx_h = 81 + stride_q_idx_n = 9 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/yb/cybuifkbme5pqec7hqplzgosjwtmpxbo3ubhvaobchgrztupyhkd.py b/progress/SpecForge/cache/compiled_kernels/yb/cybuifkbme5pqec7hqplzgosjwtmpxbo3ubhvaobchgrztupyhkd.py new file mode 100644 index 0000000000000000000000000000000000000000..475c107a0e4a892602b7a49fd37611557ab1445c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yb/cybuifkbme5pqec7hqplzgosjwtmpxbo3ubhvaobchgrztupyhkd.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/yc/277ae35f0105c15f715c94d69af7c3f1a6cd275d2ffff56d3fffbf6cc7cf098c.best_config b/progress/SpecForge/cache/compiled_kernels/yc/277ae35f0105c15f715c94d69af7c3f1a6cd275d2ffff56d3fffbf6cc7cf098c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..bb9d5aac05f4375b023cd636ab8154ecad1aa2f1 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yc/277ae35f0105c15f715c94d69af7c3f1a6cd275d2ffff56d3fffbf6cc7cf098c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 15, "triton_cache_hash": "JQE3T4T3GBYOGTKMMKENLTYUUWIESYRFFDJKZ75LSXOH7Y4E4CJA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yc/cych2j34c5zq6tvwqfjomrwqs66nalytep4dufch6knmdh36bdb3.py b/progress/SpecForge/cache/compiled_kernels/yc/cych2j34c5zq6tvwqfjomrwqs66nalytep4dufch6knmdh36bdb3.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c7afb873cc57ed281eec1287a00594204f8e75 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yc/cych2j34c5zq6tvwqfjomrwqs66nalytep4dufch6knmdh36bdb3.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 466944}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 58368 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/yc/cycraxrzf7rpuoeziiwrpvwb3uz3uqtssnjmydf3g7az5p6q7iir.py b/progress/SpecForge/cache/compiled_kernels/yc/cycraxrzf7rpuoeziiwrpvwb3uz3uqtssnjmydf3g7az5p6q7iir.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a92ec0e183db9bcac7e322c75a560a519a4efd --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yc/cycraxrzf7rpuoeziiwrpvwb3uz3uqtssnjmydf3g7az5p6q7iir.py @@ -0,0 +1,702 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ql/cqlqasxkr2aielad3y2jrg7zqunkz27araoetzma37ddh7wuyyhw.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %arg0_1 : Tensor "bf16[1, 32, 624, 128][2555904, 128, 4096, 1]cuda:4" = PlaceHolder[target=arg0_1] +# %arg1_1 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "bf16[1, 8, 624, 128][638976, 128, 1024, 1]cuda:4" = PlaceHolder[target=arg2_1] +# %getitem_1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4" = PlaceHolder[target=buf1] +# %arg3_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=arg3_1] +# %arg4_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=arg4_1] +# %arg5_1 : Tensor "i32[1, 1, 5][5, 5, 1]cuda:4" = PlaceHolder[target=arg5_1] +# %arg6_1 : Tensor "i32[1, 1, 5, 5][25, 25, 5, 1]cuda:4" = PlaceHolder[target=arg6_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%arg0_1, %arg1_1, %arg2_1, %sdpa_score0, (624, 624, %arg3_1, %arg4_1, %arg5_1, %arg6_1, %arg7_1, %arg8_1, %arg9_1, %arg10_1, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 2555904, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 638976, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 638976, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = 624 + ZKV = 1 + KV_LEN = 624 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 5 + stride_kv_idx_h = 25 + stride_kv_idx_m = 5 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 79872*idx_hq + 2555904*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/rg/crgxcf2kpnvse2bk65ukwreddvsjjcwhijeklqq7egtduqxsqwue.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul : Tensor "f32[1, 32, 624][19968, 624, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 32768}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 159744}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 19968 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args + args.clear() + assert_size_stride(arg0_1, (1, 32, 624, 128), (2555904, 128, 4096, 1)) + assert_size_stride(arg1_1, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(arg2_1, (1, 8, 624, 128), (638976, 128, 1024, 1)) + assert_size_stride(arg3_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg4_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg5_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg6_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg7_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg8_1, (1, 1, 5, 5), (25, 25, 5, 1)) + assert_size_stride(arg9_1, (1, 1, 5), (5, 5, 1)) + assert_size_stride(arg10_1, (1, 1, 5, 5), (25, 25, 5, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 624), (19968, 624, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 624, 128), (2555904, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, buf1, arg3_1, arg4_1, arg5_1, arg6_1, buf2, 5, 1, 32, stream=stream4) + del arg0_1 + del arg1_1 + del arg2_1 + del arg3_1 + del arg4_1 + del arg5_1 + del arg6_1 + buf5 = buf1; del buf1 # reuse + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, 19968, stream=stream4) + del buf0 + return (buf2, buf5, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((1, 32, 624, 128), (2555904, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + arg1_1 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg2_1 = rand_strided((1, 8, 624, 128), (638976, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + arg3_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + arg4_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + arg5_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + arg6_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + arg7_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + arg8_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + arg9_1 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:4', dtype=torch.int32) + arg10_1 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/yf/cyfwrlsuynzuq2tm2hynmskygfkbferzpyblqm4zdolfup4adbmh.py b/progress/SpecForge/cache/compiled_kernels/yf/cyfwrlsuynzuq2tm2hynmskygfkbferzpyblqm4zdolfup4adbmh.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcd6ff8c52e7fe67f2499f7a43692bb1e83276a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yf/cyfwrlsuynzuq2tm2hynmskygfkbferzpyblqm4zdolfup4adbmh.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_flex_decoding(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yi/cyig2z2wvhed6r7erirv5om5setgck6uavpfzqyl2k44gk22li5f.py b/progress/SpecForge/cache/compiled_kernels/yi/cyig2z2wvhed6r7erirv5om5setgck6uavpfzqyl2k44gk22li5f.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd1c40df7dc6bc0daf14e5b6e23358f2da95190 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yi/cyig2z2wvhed6r7erirv5om5setgck6uavpfzqyl2k44gk22li5f.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/yk/cyktniqut4hvzwm2zxp6c5v72tensuzn3sghmn25du5w3dywb7be.py b/progress/SpecForge/cache/compiled_kernels/yk/cyktniqut4hvzwm2zxp6c5v72tensuzn3sghmn25du5w3dywb7be.py new file mode 100644 index 0000000000000000000000000000000000000000..f63fc5a45d40618fed9a7f6cefd51154dd72734c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yk/cyktniqut4hvzwm2zxp6c5v72tensuzn3sghmn25du5w3dywb7be.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ym/1485f8e682c7216bb743c177f6d3e6be352db6fa11243c7b38252c0ce969f99c.best_config b/progress/SpecForge/cache/compiled_kernels/ym/1485f8e682c7216bb743c177f6d3e6be352db6fa11243c7b38252c0ce969f99c.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ym/1485f8e682c7216bb743c177f6d3e6be352db6fa11243c7b38252c0ce969f99c.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ym/cymmhz7ggoire32qvayd4g525ooohofhikb6v7n5vukewzni3xbe.py b/progress/SpecForge/cache/compiled_kernels/ym/cymmhz7ggoire32qvayd4g525ooohofhikb6v7n5vukewzni3xbe.py new file mode 100644 index 0000000000000000000000000000000000000000..5134e8b7f8b6aff199b71d64b8dceed6282cf324 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ym/cymmhz7ggoire32qvayd4g525ooohofhikb6v7n5vukewzni3xbe.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py b/progress/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py new file mode 100644 index 0000000000000000000000000000000000000000..2088b28a422c9942d3c8d37b8acb187d20a65b45 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yn/cynm7qmybz3bfizmxg6zv3qhcckex7efjwksscbnyq6ubsjwnpjg.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 1 + stride_kv_idx_h = 1 + stride_kv_idx_m = 1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 1 + stride_q_idx_h = 1 + stride_q_idx_n = 1 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yo/cyovrkn5g6t5dss5uyy7ufwyguhqgajhrqatxilbhyxj6rlek4ae.py b/progress/SpecForge/cache/compiled_kernels/yo/cyovrkn5g6t5dss5uyy7ufwyguhqgajhrqatxilbhyxj6rlek4ae.py new file mode 100644 index 0000000000000000000000000000000000000000..7ffd267f6fe8862c2c3d8e6c20e064226e1027e4 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yo/cyovrkn5g6t5dss5uyy7ufwyguhqgajhrqatxilbhyxj6rlek4ae.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/yo/e232849134005d649049d45c1fec3a64a82f288774853b7872c2e448c229fbbf.best_config b/progress/SpecForge/cache/compiled_kernels/yo/e232849134005d649049d45c1fec3a64a82f288774853b7872c2e448c229fbbf.best_config new file mode 100644 index 0000000000000000000000000000000000000000..f83985899fc91dde26be1835ade2c28101324946 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yo/e232849134005d649049d45c1fec3a64a82f288774853b7872c2e448c229fbbf.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 12, "triton_cache_hash": "2ZIFGDABR2MKMG7ESWF67GBZDP27JEZIQWMBXPOUZFGMG5PW5DSA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yr/002899dc230df6663cc25187fa909ea68fc0510edead336c4f384fe24b81d70a.best_config b/progress/SpecForge/cache/compiled_kernels/yr/002899dc230df6663cc25187fa909ea68fc0510edead336c4f384fe24b81d70a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..8db2bf1768fdc31daee2445f8ed6e932b050cfff --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yr/002899dc230df6663cc25187fa909ea68fc0510edead336c4f384fe24b81d70a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 15, "triton_cache_hash": "HVNEIFRPG5FMYAP4JLC6YMNYKQB4XH24HVAPIISMR2X62PUFTG4Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/yr/cyrolr3isgcukbhhfvhukkk5rnexr7yvexiak2u5gv4at6vg32wy.py b/progress/SpecForge/cache/compiled_kernels/yr/cyrolr3isgcukbhhfvhukkk5rnexr7yvexiak2u5gv4at6vg32wy.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cf05eaa50ffc785503099248d9fb7c082fbb8c --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yr/cyrolr3isgcukbhhfvhukkk5rnexr7yvexiak2u5gv4at6vg32wy.py @@ -0,0 +1,27 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 311296}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 38912 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/yz/cyzfh32o6gsjr3kk4ojomaonyoml7u2mucikoymtxdqhgp6xmhbv.py b/progress/SpecForge/cache/compiled_kernels/yz/cyzfh32o6gsjr3kk4ojomaonyoml7u2mucikoymtxdqhgp6xmhbv.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b4a50ff76d6f162714642d8439e6810caedf93 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yz/cyzfh32o6gsjr3kk4ojomaonyoml7u2mucikoymtxdqhgp6xmhbv.py @@ -0,0 +1,799 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/progress/SpecForge/cache/compiled_kernels/yz/cyzv2l7q32ajyi564ekzfaz763rywqor53ssb7tndlcgxdpuv2ze.py b/progress/SpecForge/cache/compiled_kernels/yz/cyzv2l7q32ajyi564ekzfaz763rywqor53ssb7tndlcgxdpuv2ze.py new file mode 100644 index 0000000000000000000000000000000000000000..e99ec0d87eab61d27f05aa50fa1d68328de4ca31 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/yz/cyzv2l7q32ajyi564ekzfaz763rywqor53ssb7tndlcgxdpuv2ze.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/z6/cz6dmuuvvrgstzh6pvk5sdlmuby5rjlwft7i7nrvo7xiqwex35py.py b/progress/SpecForge/cache/compiled_kernels/z6/cz6dmuuvvrgstzh6pvk5sdlmuby5rjlwft7i7nrvo7xiqwex35py.py new file mode 100644 index 0000000000000000000000000000000000000000..4b16475f79e524f30d22e91e3040eed337e3e932 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/z6/cz6dmuuvvrgstzh6pvk5sdlmuby5rjlwft7i7nrvo7xiqwex35py.py @@ -0,0 +1,534 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_flex_attention(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/za/czavreibvu56vq4htyyxtbexpf3r3xsyhsclf2t4oq5z37c2h7e5.py b/progress/SpecForge/cache/compiled_kernels/za/czavreibvu56vq4htyyxtbexpf3r3xsyhsclf2t4oq5z37c2h7e5.py new file mode 100644 index 0000000000000000000000000000000000000000..dafafeb2268acede2b1dbb76a332f001240d2a5a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/za/czavreibvu56vq4htyyxtbexpf3r3xsyhsclf2t4oq5z37c2h7e5.py @@ -0,0 +1,582 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 512, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks2, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'ieee' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 512 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ze/7fbec52107340bebb95af79eea3d5419a856ef1761f614b28b8182afa45bdda6.best_config b/progress/SpecForge/cache/compiled_kernels/ze/7fbec52107340bebb95af79eea3d5419a856ef1761f614b28b8182afa45bdda6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..c56b3ff6df726fa6b67725165b8989b16d820629 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ze/7fbec52107340bebb95af79eea3d5419a856ef1761f614b28b8182afa45bdda6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 8, "num_warps": 2, "num_stages": 1, "configs_hash": "1542f544a12adfb1397c535fa16687cc79c79a22e4c9cd8af0b373891f747e62", "found_by_coordesc": false, "time_taken_ms": 25, "triton_cache_hash": "TXSYSOZLAKY2QDOCJ7ELQ7KD2AGNEFY3BAYQBACOZ4CYRW5NLH4A"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/ze/czem3fqfwlqohxvpjx3ojg5cfagoqnqmqo7bqx6oqze4sdb2gel5.py b/progress/SpecForge/cache/compiled_kernels/ze/czem3fqfwlqohxvpjx3ojg5cfagoqnqmqo7bqx6oqze4sdb2gel5.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e3d47c7bd4b91ab0e9b1666d59d4d52f6dcbeb --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/ze/czem3fqfwlqohxvpjx3ojg5cfagoqnqmqo7bqx6oqze4sdb2gel5.py @@ -0,0 +1,59 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/zh/czhq74zuiutlf2ryzvnaaaajkil72sm4lia6tjnuf6qwa6ldhqsu.py b/progress/SpecForge/cache/compiled_kernels/zh/czhq74zuiutlf2ryzvnaaaajkil72sm4lia6tjnuf6qwa6ldhqsu.py new file mode 100644 index 0000000000000000000000000000000000000000..9c68d7e16e52bf7a289ed6b92378a6a298c9106a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zh/czhq74zuiutlf2ryzvnaaaajkil72sm4lia6tjnuf6qwa6ldhqsu.py @@ -0,0 +1,1046 @@ +# AOT ID: ['2_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/hk/chkgaqygcy26eau7yvwmovujnogazcg54xtra6atvqxpqdityswt.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf0] +# %tangents_2 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=tangents_2] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_mul_0 = async_compile.triton('triton_red_fused_mul_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32768, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x0 + 128*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/uu/cuuphv47i4uyclwboo4osj65ngxk6w662itmcl6uvwtstfmp6b5w.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[1, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:0" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:0" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_20 : Tensor "i32[1, 1, s56][s56, s56, 1]cuda:0" = PlaceHolder[target=primals_20] +# %primals_23 : Tensor "i32[1, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:0" = PlaceHolder[target=primals_23] +# %primals_15 : Tensor "i32[1, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_15] +# %primals_18 : Tensor "i32[1, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_18] +# %primals_25 : Tensor "i32[1, 1, s100][s100, s100, 1]cuda:0" = PlaceHolder[target=primals_25] +# %primals_28 : Tensor "i32[1, 1, s5, s10][s10*s5, s10*s5, s10, 1]cuda:0" = PlaceHolder[target=primals_28] +# %mul_19 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, 0.6931471805599453), kwargs = {}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %mul_19, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_15, %primals_18, %primals_20, %primals_23, %primals_25, %primals_28, 128, 128, %mask_graph0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem_4 +triton_tem_fused_mul_1 = async_compile.triton('triton_tem_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_mul_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_mul_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128, 1024, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp21 = (ds) + grad_scores = tmp21 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp22 = (qkT) + post_mod_scores = tmp22 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp23 = (m) + tmp24 = tl.full([1], 0, tl.int32) + tmp25 = tmp23 < tmp24 + tmp26 = (n) + tmp27 = tmp26 <= tmp23 + tmp28 = tmp25 & tmp27 + tmp29 = tmp23 >= tmp24 + tmp30 = tmp26 < tmp24 + tmp31 = tmp29 & tmp30 + tmp32 = tmp30 == 0 + tmp33 = tmp29 & tmp32 + tmp34 = tmp23 - tmp24 + tmp35 = tl.full([1], 16, tl.int32) + tmp36 = tl.where((tmp34 < 0) != (tmp35 < 0), tl.where(tmp34 % tmp35 != 0, tmp34 // tmp35 - 1, tmp34 // tmp35), tmp34 // tmp35) + tmp37 = tmp26 - tmp24 + tmp38 = tl.where((tmp37 < 0) != (tmp35 < 0), tl.where(tmp37 % tmp35 != 0, tmp37 // tmp35 - 1, tmp37 // tmp35), tmp37 // tmp35) + tmp39 = tmp36 == tmp38 + tmp40 = tmp33 & tmp39 + tmp41 = tmp31 | tmp40 + tmp42 = tmp28 | tmp41 + mask_mod_output = tmp42 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp43 = (dsT) + grad_scores = tmp43 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_14 + s28 = primals_16 + s4 = primals_17 + s56 = primals_19 + s53 = primals_22 + s84 = primals_21 + s100 = primals_24 + s10 = primals_27 + s5 = primals_26 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_15, (1, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_18, (1, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_20, (1, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_23, (1, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_25, (1, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_28, (1, 1, s5, s10), (s10*s5, s10*s5, s10, 1)) + assert_size_stride(getitem, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (1, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + assert_size_stride(tangents_2, (1, 32, s37), (32*max(1, s37), max(1, s37), 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + triton_red_fused_mul_0_xnumel = 32*s37 + stream0 = get_raw_stream(0) + triton_red_fused_mul_0.run(getitem, tangents_1, tangents_2, buf1, s37, triton_red_fused_mul_0_xnumel, 128, stream=stream0) + del getitem + del tangents_2 + buf3 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + buf5 = empty_strided_cuda((1, 8, s0, 128), (1024*s0, 128, 1024, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.mul] + stream0 = get_raw_stream(0) + triton_tem_fused_mul_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_20, primals_23, primals_15, primals_18, primals_25, primals_28, buf5, s37, s0, s99, s22, s72, s56, s53, s84, 4*((127 + s37) // 128) + ((127 + s0) // 128), 1, 8, stream=stream0) + del buf1 + del getitem_1 + del primals_13 + del primals_15 + del primals_18 + del primals_2 + del primals_20 + del primals_23 + del primals_25 + del primals_28 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 528 + primals_11 = 528 + primals_7 = 5 + primals_8 = 5 + primals_12 = 5 + primals_14 = 5 + primals_16 = 5 + primals_17 = 5 + primals_19 = 5 + primals_22 = 5 + primals_21 = 5 + primals_24 = 5 + primals_27 = 5 + primals_26 = 5 + primals_2 = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_4 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 8, 528, 128), (540672, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) + primals_9 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_20 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_23 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + primals_25 = rand_strided((1, 1, 5), (5, 5, 1), device='cuda:0', dtype=torch.int32) + primals_28 = rand_strided((1, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.int32) + getitem = rand_strided((1, 32, 528, 128), (2162688, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + getitem_1 = rand_strided((1, 32, 528), (16896, 528, 1), device='cuda:0', dtype=torch.float32) + tangents_1 = rand_strided((1, 32, 528, 128), (2162688, 67584, 128, 1), device='cuda:0', dtype=torch.bfloat16) + tangents_2 = rand_strided((1, 32, 528), (16896, 528, 1), device='cuda:0', dtype=torch.float32) + fn = lambda: call([primals_10, primals_11, primals_7, primals_8, primals_12, primals_14, primals_16, primals_17, primals_19, primals_22, primals_21, primals_24, primals_27, primals_26, primals_2, primals_4, primals_6, primals_9, primals_13, primals_15, primals_18, primals_20, primals_23, primals_25, primals_28, getitem, getitem_1, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/zi/0d93da4ce1731e5ac6f79f660c60ce3c57cb1ebf9808c2c1e4dc5843cceeb997.best_config b/progress/SpecForge/cache/compiled_kernels/zi/0d93da4ce1731e5ac6f79f660c60ce3c57cb1ebf9808c2c1e4dc5843cceeb997.best_config new file mode 100644 index 0000000000000000000000000000000000000000..3514cd4b7efbcfec5a9027a69143f1b0c3ed176a --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zi/0d93da4ce1731e5ac6f79f660c60ce3c57cb1ebf9808c2c1e4dc5843cceeb997.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "7cced77f371acaa5aa7d90332a90e0c907727cfefb71d9cc9d997c24557fc44f", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "INTOFYV77CKWGYMNYQFUZRDV3WKWLRKMHMFBFKCSJUXNCIG7GF7Q"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/zi/czisxx2rasv3jxm6nvp2yhkwlbob3ysmyfupuctvxqfwp52ekm3f.py b/progress/SpecForge/cache/compiled_kernels/zi/czisxx2rasv3jxm6nvp2yhkwlbob3ysmyfupuctvxqfwp52ekm3f.py new file mode 100644 index 0000000000000000000000000000000000000000..21d0aa456d2a5005dc6f044daff64dcb524159c0 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zi/czisxx2rasv3jxm6nvp2yhkwlbob3ysmyfupuctvxqfwp52ekm3f.py @@ -0,0 +1,28 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/zq/czq2zwxrzhcb4scf4ilvtlcyktcbslyon5eoixfxwfjhprt5ucao.py b/progress/SpecForge/cache/compiled_kernels/zq/czq2zwxrzhcb4scf4ilvtlcyktcbslyon5eoixfxwfjhprt5ucao.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac1ce5933247e63b8f96f313a290e7c6f2c910f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zq/czq2zwxrzhcb4scf4ilvtlcyktcbslyon5eoixfxwfjhprt5ucao.py @@ -0,0 +1,876 @@ +# AOT ID: ['3_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/s2/cs2ckboafjbrrbo5wr4xu4knlgkjyyoeypuyz3hhzzdabjrhvnjw.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:6" = PlaceHolder[target=primals_6] +# %buf0 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf0] +# %buf1 : Tensor "f32[1, 32, 32, s37][1024*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf1] +# %primals_10 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[1, 1, 1][1, 1, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[1, 1, 1, 1][1, 1, 1, 1]cuda:6" = PlaceHolder[target=primals_12] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %buf2 +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=2, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_M': '*fp32', 'arg_L': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'SM_SCALE': 0.08838834764831845, 'SPLIT_KV': 32, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M': 256, 'SAFE_M_BOUNDARY': False, 'SAFE_N_BOUNDARY': True, 'BLOCK_N': 64, 'SPARSE_KV_BLOCK_SIZE': 128, 'USE_TMA': False}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + M = arg_M + L = arg_L + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = 4096*ks0, 512, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + stride_mz, stride_mt, stride_mh, stride_mm = 1024*ks0, 32*ks0, ks0, 1 + stride_lz, stride_lt, stride_lh, stride_lm = 1024*ks0, 32*ks0, ks0, 1 + + + Z = 1 + ZKV = 1 + HKV = 8 + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = ks0 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0).to(INDEX_DTYPE) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0).to(INDEX_DTYPE) % HKV + off_t = tl.program_id(1).to(INDEX_DTYPE) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + K = K + k_offset + V = V + v_offset + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = 1, 1, 1 + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = 1, 1, 1, 1 + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + MAX_KV_IDX = 1 + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + desc_k = None + desc_v = None + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = (block_n_start // SPARSE_KV_MULTIPLE) % (MAX_KV_IDX) + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + off_n, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0 + 131072*idx_z*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_t*ks0, acc.shape)), acc, mask) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831845 + SPLIT_KV : tl.constexpr = 32 + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M : tl.constexpr = 256 + SAFE_M_BOUNDARY : tl.constexpr = False + SAFE_N_BOUNDARY : tl.constexpr = True + BLOCK_N : tl.constexpr = 64 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + USE_TMA : tl.constexpr = False + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_M, arg_L, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ty/ctye46k4gahd2e44muti52bzvk5ppd4o67yh77qc74mlf4iqkey7.py +# Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem_1 +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf4 : Tensor = PlaceHolder[target=buf4] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem_1 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=2] = call_function[target=operator.getitem](args = (%flex_attention, 1), kwargs = {}) +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %buf5,%buf7,%getitem_1,%mul_15 +triton_per_fused_mul_1 = async_compile.triton('triton_per_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 2048, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_mul_1(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + x2 = (xindex % ks0) + x3 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp5 = tl.load(in_ptr1 + (x0 + 32*ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, float("-inf")) + tmp4 = triton_helpers.max2(tmp3, 1)[:, None].to(tl.float32) + tmp6 = float("-inf") + tmp7 = tmp4 == tmp6 + tmp8 = tmp0 - tmp4 + tmp9 = 0.0 + tmp10 = tl.where(tmp7, tmp9, tmp8) + tmp11 = libdevice.exp2(tmp10) + tmp12 = tmp5 * tmp11 + tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) + tmp15 = tl.where(xmask, tmp13, 0) + tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32) + tmp17 = 1.0 + tmp18 = tl.where(tmp7, tmp17, tmp16) + tmp19 = libdevice.log2(tmp18) + tmp20 = tmp19 + tmp4 + tmp21 = 0.6931471805599453 + tmp22 = tmp20 * tmp21 + tl.store(out_ptr2 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp20, xmask) + tl.store(out_ptr3 + (x2 + x3*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp22, xmask) + tl.store(out_ptr0 + (x0), tmp4, xmask) + tl.store(out_ptr1 + (x0), tmp16, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/gs/cgsdbajyloxe47jtfxdbfdi7ye3lgfn4qhfvprge2u6zknepvixf.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention, getitem +# Graph fragment: +# %buf2 : Tensor "f32[1, 32, 32, s37, 128][131072*s37, 4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf2] +# %buf5 : Tensor "f32[1, 1, 32, s37][32*s37, 32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf5] +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %buf8 : Tensor "f32[1, 32, s37, 128][4096*s37, 128*s37, 128, 1]cuda:6" = PlaceHolder[target=buf8] +# %buf7 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:6" = PlaceHolder[target=buf7] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_8, %primals_9, %primals_10, %primals_7, %primals_11, %primals_12, %primals_13, %primals_14, %primals_15, %primals_16, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# %getitem : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:6"[num_users=1] = call_function[target=operator.getitem](args = (%flex_attention, 0), kwargs = {}) +# return %buf8,%getitem +triton_per_fused_2 = async_compile.triton('triton_per_fused_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 262144, 'r0_': 32}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_2', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': None, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x5 = xindex + x1 = xindex // 128 + x0 = (xindex % 128) + x3 = ((xindex // 128) % ks0) + x4 = xindex // ks1 + tmp0 = tl.load(in_ptr0 + (x5 + 4096*ks0*r0_2), None) + tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last') + tmp4 = tl.load(in_ptr2 + (x1 + 32*ks0*r0_2), None, eviction_policy='evict_last') + tmp13 = tl.load(in_ptr3 + (x1), None, eviction_policy='evict_last') + tmp2 = float("-inf") + tmp3 = tmp1 == tmp2 + tmp5 = tmp4 - tmp1 + tmp6 = 0.0 + tmp7 = tl.where(tmp3, tmp6, tmp5) + tmp8 = libdevice.exp2(tmp7) + tmp9 = tmp0 * tmp8 + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = tl.sum(tmp10, 1)[:, None].to(tl.float32) + tmp14 = 1.0 + tmp15 = tl.where(tmp3, tmp14, tmp13) + tmp16 = (tmp12 / tmp15) + tmp17 = tmp16.to(tl.float32) + tl.store(out_ptr1 + (x0 + 128*x4 + 4096*x3), tmp17, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s37 = primals_8 + s71 = primals_9 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_7, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_10, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_11, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_12, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_13, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_14, (1, 1, 1, 1), (1, 1, 1, 1)) + assert_size_stride(primals_15, (1, 1, 1), (1, 1, 1)) + assert_size_stride(primals_16, (1, 1, 1, 1), (1, 1, 1, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf0 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, 32, s37), (1024*s37, 32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, 32, s37, 128), (131072*s37, 4096*s37, 128*s37, 128, 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream6 = get_raw_stream(6) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_10, primals_7, primals_11, primals_12, buf2, s37, s0, 8, 32, 1, stream=stream6) + buf5 = empty_strided_cuda((1, 1, 32, s37), (32*s37, 32*s37, s37, 1), torch.float32) + buf7 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf10 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + buf11 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [flex_attention, lse_scaled], Original ATen: [aten.mul] + triton_per_fused_mul_1_xnumel = 32*s37 + stream6 = get_raw_stream(6) + triton_per_fused_mul_1.run(buf0, buf1, buf5, buf7, buf10, buf11, s37, triton_per_fused_mul_1_xnumel, 32, stream=stream6) + del buf1 + ps0 = 128*s37 + buf9 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + triton_per_fused_2_xnumel = 4096*s37 + stream6 = get_raw_stream(6) + triton_per_fused_2.run(buf2, buf5, buf0, buf7, buf9, s37, ps0, triton_per_fused_2_xnumel, 32, stream=stream6) + del buf0 + del buf2 + del buf5 + del buf7 + return (buf9, buf11, primals_2, primals_4, primals_6, primals_7, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, buf9, buf10, s37, s0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 48 + primals_2 = rand_strided((1, 32, 48, 128), (196608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = 48 + primals_4 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_5 = 48 + primals_6 = rand_strided((1, 8, 48, 128), (49152, 128, 1024, 1), device='cuda:6', dtype=torch.bfloat16) + primals_7 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_8 = 48 + primals_9 = 48 + primals_10 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_13 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 1), (1, 1, 1), device='cuda:6', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 1, 1), (1, 1, 1, 1), device='cuda:6', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/compiled_kernels/zx/ae120ec1628489609cba429fa5ac768b58094f300cecd0395c58cd8d18978c5b.best_config b/progress/SpecForge/cache/compiled_kernels/zx/ae120ec1628489609cba429fa5ac768b58094f300cecd0395c58cd8d18978c5b.best_config new file mode 100644 index 0000000000000000000000000000000000000000..47a5048cea1877a7cc5c8641104e80d8f141835f --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zx/ae120ec1628489609cba429fa5ac768b58094f300cecd0395c58cd8d18978c5b.best_config @@ -0,0 +1 @@ +{"XBLOCK": 4, "R0_BLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "48464ea7d171263ae4fed5184e32a30841f1081b8df295ec1f8e2f76e5287c9d", "found_by_coordesc": false, "time_taken_ms": 42, "triton_cache_hash": "VWP5KW7V6DWVQIOQUDOXJUC4FECFU4V5EZC2H76Y54HBZR6735SA"} \ No newline at end of file diff --git a/progress/SpecForge/cache/compiled_kernels/zx/czxnzxgrxghuyq6lg4yqhvr53wk5msdvvzbzikyl5wfl7qctkl5l.py b/progress/SpecForge/cache/compiled_kernels/zx/czxnzxgrxghuyq6lg4yqhvr53wk5msdvvzbzikyl5wfl7qctkl5l.py new file mode 100644 index 0000000000000000000000000000000000000000..10f7e6dc6f831b4e26b0a7b678c2ef5feec62779 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zx/czxnzxgrxghuyq6lg4yqhvr53wk5msdvvzbzikyl5wfl7qctkl5l.py @@ -0,0 +1,52 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 65536, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_mul_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 423936, 'r0_': 18087936}} +) +@triton.jit +def triton_red_fused_mul_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 35328 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 1104) + x1 = xindex // 1104 + x3 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_2 + 128*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp7 = tl.load(in_ptr2 + (x3), xmask, eviction_policy='evict_last') + tmp6 = tmp4.to(tl.float32) + tmp8 = 0.6931471805599453 + tmp9 = tmp7 * tmp8 + tmp10 = 1.4426950408889634 + tmp11 = tmp9 * tmp10 + tmp12 = tmp6 - tmp11 + tl.store(out_ptr1 + (x3), tmp12, xmask) diff --git a/progress/SpecForge/cache/compiled_kernels/zx/czxy4zj4bkbwn6ywrgd52ahdrfr3idcuqvqco65kh7wgxlg6ii5i.py b/progress/SpecForge/cache/compiled_kernels/zx/czxy4zj4bkbwn6ywrgd52ahdrfr3idcuqvqco65kh7wgxlg6ii5i.py new file mode 100644 index 0000000000000000000000000000000000000000..e0387f5ebe7901d254187a68da69b40f20e84767 --- /dev/null +++ b/progress/SpecForge/cache/compiled_kernels/zx/czxy4zj4bkbwn6ywrgd52ahdrfr3idcuqvqco65kh7wgxlg6ii5i.py @@ -0,0 +1,713 @@ +# AOT ID: ['1_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/or/cor47z6eozmuk4wdukvlmvzymu3noysuwi3iwfctlrm6sdlvjt4r.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[1, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 8, s0, 128][1024*s0, 128, 1024, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[1, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[1, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[1, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_14 : Tensor "i32[1, 1, 10][10, 10, 1]cuda:4" = PlaceHolder[target=primals_14] +# %primals_15 : Tensor "i32[1, 1, 10, 10][100, 100, 10, 1]cuda:4" = PlaceHolder[target=primals_15] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_14, %primals_15, %primals_16, %primals_17, %primals_18, %primals_19, 128, 128, %sdpa_mask0), 0.08838834764831845, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), ()), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831845, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128, 1024, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128, 1024, 1 + + ZQ = 1 + HQ = 32 + Q_LEN = ks0 + ZKV = 1 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 1 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = (m) + tmp2 = tl.full([1], 0, tl.int32) + tmp3 = tmp1 < tmp2 + tmp4 = (n) + tmp5 = tmp4 <= tmp1 + tmp6 = tmp3 & tmp5 + tmp7 = tmp1 >= tmp2 + tmp8 = tmp4 < tmp2 + tmp9 = tmp7 & tmp8 + tmp10 = tmp8 == 0 + tmp11 = tmp7 & tmp10 + tmp12 = tmp1 - tmp2 + tmp13 = tl.full([1], 16, tl.int32) + tmp14 = tl.where((tmp12 < 0) != (tmp13 < 0), tl.where(tmp12 % tmp13 != 0, tmp12 // tmp13 - 1, tmp12 // tmp13), tmp12 // tmp13) + tmp15 = tmp4 - tmp2 + tmp16 = tl.where((tmp15 < 0) != (tmp13 < 0), tl.where(tmp15 % tmp13 != 0, tmp15 // tmp13 - 1, tmp15 // tmp13), tmp15 // tmp13) + tmp17 = tmp14 == tmp16 + tmp18 = tmp11 & tmp17 + tmp19 = tmp9 | tmp18 + tmp20 = tmp6 | tmp19 + mask_mod_output = tmp20 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831845 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, out_ptr0, ks0, ks1, ks2, ks3, ks4, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/so/cson7svzlvyktivet44qjkefeudhlgzxa626d4cjoshcibhnkyyh.py +# Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] +# Source node to ATen node mapping: +# lse_scaled => mul_15 +# Graph fragment: +# %buf3 : Tensor = PlaceHolder[target=buf3] +# %mul_15 : Tensor "f32[1, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%getitem_1, 0.6931471805599453), kwargs = {}) +# return %mul_15 +triton_poi_fused_mul_1 = async_compile.triton('triton_poi_fused_mul_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 65536}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_mul_1(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x2 = xindex + x0 = (xindex % ks0) + x1 = triton_helpers.div_floor_integer(xindex, ks0) + tmp0 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + tmp1 = 0.6931471805599453 + tmp2 = tmp0 * tmp1 + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp2, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + assert_size_stride(primals_2, (1, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_6, (1, 8, s0, 128), (1024*s0, 128, 1024, 1)) + assert_size_stride(primals_9, (1, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (1, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_15, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(primals_16, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_17, (1, 1, 10, 10), (100, 100, 10, 1)) + assert_size_stride(primals_18, (1, 1, 10), (10, 10, 1)) + assert_size_stride(primals_19, (1, 1, 10, 10), (100, 100, 10, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((1, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((1, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_14, primals_15, buf2, s37, s0, s99, s22, s72, (127 + s37) // 128, 1, 32, stream=stream4) + del buf1 + buf5 = empty_strided_cuda((1, 32, s37), (32*max(1, s37), max(1, s37), 1), torch.float32) + # Topologically Sorted Source Nodes: [lse_scaled], Original ATen: [aten.mul] + triton_poi_fused_mul_1_xnumel = 32*s37 + stream4 = get_raw_stream(4) + triton_poi_fused_mul_1.run(buf0, buf5, s37, triton_poi_fused_mul_1_xnumel, stream=stream4) + return (buf2, buf5, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, buf2, buf0, s37, s0, s22, s72, s99, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1280 + primals_2 = rand_strided((1, 32, 1280, 128), (5242880, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = 1280 + primals_4 = rand_strided((1, 8, 1280, 128), (1310720, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_5 = 1280 + primals_6 = rand_strided((1, 8, 1280, 128), (1310720, 128, 1024, 1), device='cuda:4', dtype=torch.bfloat16) + primals_7 = 10 + primals_8 = 10 + primals_9 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_10 = 1280 + primals_11 = 1280 + primals_12 = 10 + primals_13 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_15 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_16 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_17 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + primals_18 = rand_strided((1, 1, 10), (10, 10, 1), device='cuda:4', dtype=torch.int32) + primals_19 = rand_strided((1, 1, 10, 10), (100, 100, 10, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/progress/SpecForge/cache/processed_dataset/tmp067p25oc b/progress/SpecForge/cache/processed_dataset/tmp067p25oc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp0ah8jku4 b/progress/SpecForge/cache/processed_dataset/tmp0ah8jku4 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp0sclu3x4 b/progress/SpecForge/cache/processed_dataset/tmp0sclu3x4 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp0ybctu6i b/progress/SpecForge/cache/processed_dataset/tmp0ybctu6i new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp0yg3ssj_ b/progress/SpecForge/cache/processed_dataset/tmp0yg3ssj_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp10nonl7_ b/progress/SpecForge/cache/processed_dataset/tmp10nonl7_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp11hcyov_ b/progress/SpecForge/cache/processed_dataset/tmp11hcyov_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp1360xfdd b/progress/SpecForge/cache/processed_dataset/tmp1360xfdd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp1hyq29gi b/progress/SpecForge/cache/processed_dataset/tmp1hyq29gi new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp1ofsvqn_ b/progress/SpecForge/cache/processed_dataset/tmp1ofsvqn_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp26fcprsh b/progress/SpecForge/cache/processed_dataset/tmp26fcprsh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp26gfzekz b/progress/SpecForge/cache/processed_dataset/tmp26gfzekz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp2cdl8jpo b/progress/SpecForge/cache/processed_dataset/tmp2cdl8jpo new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp2dd6va0s b/progress/SpecForge/cache/processed_dataset/tmp2dd6va0s new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp2leutbrd b/progress/SpecForge/cache/processed_dataset/tmp2leutbrd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp30oet8hi b/progress/SpecForge/cache/processed_dataset/tmp30oet8hi new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp37_ku82u b/progress/SpecForge/cache/processed_dataset/tmp37_ku82u new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp3hf5n1p_ b/progress/SpecForge/cache/processed_dataset/tmp3hf5n1p_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp3idfyz9u b/progress/SpecForge/cache/processed_dataset/tmp3idfyz9u new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp3jt1y_x6 b/progress/SpecForge/cache/processed_dataset/tmp3jt1y_x6 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp3ljbv8k5 b/progress/SpecForge/cache/processed_dataset/tmp3ljbv8k5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp3yfdqo8i b/progress/SpecForge/cache/processed_dataset/tmp3yfdqo8i new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp464kq8_e b/progress/SpecForge/cache/processed_dataset/tmp464kq8_e new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp4an6pbf8 b/progress/SpecForge/cache/processed_dataset/tmp4an6pbf8 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp4komzmfb b/progress/SpecForge/cache/processed_dataset/tmp4komzmfb new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp4y_e7dnf b/progress/SpecForge/cache/processed_dataset/tmp4y_e7dnf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp4ye9r2ux b/progress/SpecForge/cache/processed_dataset/tmp4ye9r2ux new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp53b0wc3z b/progress/SpecForge/cache/processed_dataset/tmp53b0wc3z new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp5swzyv7w b/progress/SpecForge/cache/processed_dataset/tmp5swzyv7w new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp5y6mfzqa b/progress/SpecForge/cache/processed_dataset/tmp5y6mfzqa new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp61ktwa2_ b/progress/SpecForge/cache/processed_dataset/tmp61ktwa2_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp62_o669t b/progress/SpecForge/cache/processed_dataset/tmp62_o669t new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp64g57t2g b/progress/SpecForge/cache/processed_dataset/tmp64g57t2g new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6d4u6cce b/progress/SpecForge/cache/processed_dataset/tmp6d4u6cce new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6fgksr05 b/progress/SpecForge/cache/processed_dataset/tmp6fgksr05 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6fpnle35 b/progress/SpecForge/cache/processed_dataset/tmp6fpnle35 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6k0lqzd3 b/progress/SpecForge/cache/processed_dataset/tmp6k0lqzd3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6ofio5if b/progress/SpecForge/cache/processed_dataset/tmp6ofio5if new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6pvt4kj3 b/progress/SpecForge/cache/processed_dataset/tmp6pvt4kj3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp6ugd53en b/progress/SpecForge/cache/processed_dataset/tmp6ugd53en new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp775lcyav b/progress/SpecForge/cache/processed_dataset/tmp775lcyav new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp77vtjn2t b/progress/SpecForge/cache/processed_dataset/tmp77vtjn2t new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp78d5nj4c b/progress/SpecForge/cache/processed_dataset/tmp78d5nj4c new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp7bzl0fnv b/progress/SpecForge/cache/processed_dataset/tmp7bzl0fnv new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp7k2lh1mt b/progress/SpecForge/cache/processed_dataset/tmp7k2lh1mt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp7vhoyiei b/progress/SpecForge/cache/processed_dataset/tmp7vhoyiei new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp831eutse b/progress/SpecForge/cache/processed_dataset/tmp831eutse new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp8befu_n1 b/progress/SpecForge/cache/processed_dataset/tmp8befu_n1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp8cuxw0ef b/progress/SpecForge/cache/processed_dataset/tmp8cuxw0ef new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp8p2v6x2z b/progress/SpecForge/cache/processed_dataset/tmp8p2v6x2z new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp929q00qo b/progress/SpecForge/cache/processed_dataset/tmp929q00qo new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9_ewiipo b/progress/SpecForge/cache/processed_dataset/tmp9_ewiipo new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9f22nv8r b/progress/SpecForge/cache/processed_dataset/tmp9f22nv8r new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9foa4m02 b/progress/SpecForge/cache/processed_dataset/tmp9foa4m02 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9h1nz1wd b/progress/SpecForge/cache/processed_dataset/tmp9h1nz1wd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9iaq1nm7 b/progress/SpecForge/cache/processed_dataset/tmp9iaq1nm7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9nb47p73 b/progress/SpecForge/cache/processed_dataset/tmp9nb47p73 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp9r07griz b/progress/SpecForge/cache/processed_dataset/tmp9r07griz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_077eljf b/progress/SpecForge/cache/processed_dataset/tmp_077eljf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_5v8ks7l b/progress/SpecForge/cache/processed_dataset/tmp_5v8ks7l new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_a_euffg b/progress/SpecForge/cache/processed_dataset/tmp_a_euffg new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_bz9oa9c b/progress/SpecForge/cache/processed_dataset/tmp_bz9oa9c new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_dbock7t b/progress/SpecForge/cache/processed_dataset/tmp_dbock7t new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_kzu4626 b/progress/SpecForge/cache/processed_dataset/tmp_kzu4626 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_r_gjl6w b/progress/SpecForge/cache/processed_dataset/tmp_r_gjl6w new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_vh4yhzo b/progress/SpecForge/cache/processed_dataset/tmp_vh4yhzo new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmp_wa73gh1 b/progress/SpecForge/cache/processed_dataset/tmp_wa73gh1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpa_gay8ae b/progress/SpecForge/cache/processed_dataset/tmpa_gay8ae new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpafjawr6d b/progress/SpecForge/cache/processed_dataset/tmpafjawr6d new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpal44lmew b/progress/SpecForge/cache/processed_dataset/tmpal44lmew new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpat7544da b/progress/SpecForge/cache/processed_dataset/tmpat7544da new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpau12l9zt b/progress/SpecForge/cache/processed_dataset/tmpau12l9zt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpay_ep_s5 b/progress/SpecForge/cache/processed_dataset/tmpay_ep_s5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpban4_xzp b/progress/SpecForge/cache/processed_dataset/tmpban4_xzp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbkj90kq7 b/progress/SpecForge/cache/processed_dataset/tmpbkj90kq7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbo8e8aq7 b/progress/SpecForge/cache/processed_dataset/tmpbo8e8aq7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbonluxzh b/progress/SpecForge/cache/processed_dataset/tmpbonluxzh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbrilhr9q b/progress/SpecForge/cache/processed_dataset/tmpbrilhr9q new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbs4enf4x b/progress/SpecForge/cache/processed_dataset/tmpbs4enf4x new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbthyf357 b/progress/SpecForge/cache/processed_dataset/tmpbthyf357 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpbudl847h b/progress/SpecForge/cache/processed_dataset/tmpbudl847h new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpc3hjofuk b/progress/SpecForge/cache/processed_dataset/tmpc3hjofuk new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpc3yow3qu b/progress/SpecForge/cache/processed_dataset/tmpc3yow3qu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpc4p7dfre b/progress/SpecForge/cache/processed_dataset/tmpc4p7dfre new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpc51e2u2f b/progress/SpecForge/cache/processed_dataset/tmpc51e2u2f new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpccmzz19j b/progress/SpecForge/cache/processed_dataset/tmpccmzz19j new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpcf7gtxzg b/progress/SpecForge/cache/processed_dataset/tmpcf7gtxzg new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpd2z8ii0w b/progress/SpecForge/cache/processed_dataset/tmpd2z8ii0w new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpd7oe4izl b/progress/SpecForge/cache/processed_dataset/tmpd7oe4izl new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpd_5jr0hz b/progress/SpecForge/cache/processed_dataset/tmpd_5jr0hz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdb2alen_ b/progress/SpecForge/cache/processed_dataset/tmpdb2alen_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpde119jhi b/progress/SpecForge/cache/processed_dataset/tmpde119jhi new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdk2qhow7 b/progress/SpecForge/cache/processed_dataset/tmpdk2qhow7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdlmf6djw b/progress/SpecForge/cache/processed_dataset/tmpdlmf6djw new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdqz31gww b/progress/SpecForge/cache/processed_dataset/tmpdqz31gww new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdump9dm8 b/progress/SpecForge/cache/processed_dataset/tmpdump9dm8 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdyvpk4vz b/progress/SpecForge/cache/processed_dataset/tmpdyvpk4vz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpdzgk57rf b/progress/SpecForge/cache/processed_dataset/tmpdzgk57rf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpec_btlex b/progress/SpecForge/cache/processed_dataset/tmpec_btlex new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpedmna4xp b/progress/SpecForge/cache/processed_dataset/tmpedmna4xp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpegn88_cj b/progress/SpecForge/cache/processed_dataset/tmpegn88_cj new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpek2gpo1j b/progress/SpecForge/cache/processed_dataset/tmpek2gpo1j new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpek_l1tf7 b/progress/SpecForge/cache/processed_dataset/tmpek_l1tf7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpepmf5lsj b/progress/SpecForge/cache/processed_dataset/tmpepmf5lsj new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpf1t7rtyg b/progress/SpecForge/cache/processed_dataset/tmpf1t7rtyg new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpf402d5qr b/progress/SpecForge/cache/processed_dataset/tmpf402d5qr new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpf_jto4vu b/progress/SpecForge/cache/processed_dataset/tmpf_jto4vu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfigf0ato b/progress/SpecForge/cache/processed_dataset/tmpfigf0ato new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfkc7utz9 b/progress/SpecForge/cache/processed_dataset/tmpfkc7utz9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfm407iw3 b/progress/SpecForge/cache/processed_dataset/tmpfm407iw3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfmlgpbs_ b/progress/SpecForge/cache/processed_dataset/tmpfmlgpbs_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfp_nin1p b/progress/SpecForge/cache/processed_dataset/tmpfp_nin1p new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpfpjwv9jp b/progress/SpecForge/cache/processed_dataset/tmpfpjwv9jp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgcxxy_6l b/progress/SpecForge/cache/processed_dataset/tmpgcxxy_6l new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgd_oovw8 b/progress/SpecForge/cache/processed_dataset/tmpgd_oovw8 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgej0d6q5 b/progress/SpecForge/cache/processed_dataset/tmpgej0d6q5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgewvdpvq b/progress/SpecForge/cache/processed_dataset/tmpgewvdpvq new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgk1qo38r b/progress/SpecForge/cache/processed_dataset/tmpgk1qo38r new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgmw3migw b/progress/SpecForge/cache/processed_dataset/tmpgmw3migw new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpgqu55_qc b/progress/SpecForge/cache/processed_dataset/tmpgqu55_qc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmph1hago40 b/progress/SpecForge/cache/processed_dataset/tmph1hago40 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphja4d5ma b/progress/SpecForge/cache/processed_dataset/tmphja4d5ma new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphk49f1lu b/progress/SpecForge/cache/processed_dataset/tmphk49f1lu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphko3r8y7 b/progress/SpecForge/cache/processed_dataset/tmphko3r8y7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphkz_7cti b/progress/SpecForge/cache/processed_dataset/tmphkz_7cti new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphml_gxur b/progress/SpecForge/cache/processed_dataset/tmphml_gxur new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphr_bvg0d b/progress/SpecForge/cache/processed_dataset/tmphr_bvg0d new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphxa_62oe b/progress/SpecForge/cache/processed_dataset/tmphxa_62oe new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmphzufr1sx b/progress/SpecForge/cache/processed_dataset/tmphzufr1sx new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpi0_elt2a b/progress/SpecForge/cache/processed_dataset/tmpi0_elt2a new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpi111tte7 b/progress/SpecForge/cache/processed_dataset/tmpi111tte7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpi69f6zi_ b/progress/SpecForge/cache/processed_dataset/tmpi69f6zi_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpigbb6oxc b/progress/SpecForge/cache/processed_dataset/tmpigbb6oxc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpiha5a427 b/progress/SpecForge/cache/processed_dataset/tmpiha5a427 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpihnkl80j b/progress/SpecForge/cache/processed_dataset/tmpihnkl80j new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpiih5t1e5 b/progress/SpecForge/cache/processed_dataset/tmpiih5t1e5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpikeg10xs b/progress/SpecForge/cache/processed_dataset/tmpikeg10xs new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpikzf6n_0 b/progress/SpecForge/cache/processed_dataset/tmpikzf6n_0 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpix3rysxt b/progress/SpecForge/cache/processed_dataset/tmpix3rysxt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpj4kq8qsb b/progress/SpecForge/cache/processed_dataset/tmpj4kq8qsb new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjahrtm_5 b/progress/SpecForge/cache/processed_dataset/tmpjahrtm_5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjbgmkoxi b/progress/SpecForge/cache/processed_dataset/tmpjbgmkoxi new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpji7e9fq_ b/progress/SpecForge/cache/processed_dataset/tmpji7e9fq_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjloe_hn9 b/progress/SpecForge/cache/processed_dataset/tmpjloe_hn9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjtf_isuy b/progress/SpecForge/cache/processed_dataset/tmpjtf_isuy new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjth21kue b/progress/SpecForge/cache/processed_dataset/tmpjth21kue new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpjwzudvsx b/progress/SpecForge/cache/processed_dataset/tmpjwzudvsx new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpk3xi390z b/progress/SpecForge/cache/processed_dataset/tmpk3xi390z new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpk70p53wq b/progress/SpecForge/cache/processed_dataset/tmpk70p53wq new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpk7_416rr b/progress/SpecForge/cache/processed_dataset/tmpk7_416rr new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpkhf0cgi9 b/progress/SpecForge/cache/processed_dataset/tmpkhf0cgi9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpkie75t2p b/progress/SpecForge/cache/processed_dataset/tmpkie75t2p new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpkqm72c01 b/progress/SpecForge/cache/processed_dataset/tmpkqm72c01 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpktl77nav b/progress/SpecForge/cache/processed_dataset/tmpktl77nav new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpkyvvy5yo b/progress/SpecForge/cache/processed_dataset/tmpkyvvy5yo new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpl2ner2wx b/progress/SpecForge/cache/processed_dataset/tmpl2ner2wx new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpl91sc79l b/progress/SpecForge/cache/processed_dataset/tmpl91sc79l new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpliqu89ww b/progress/SpecForge/cache/processed_dataset/tmpliqu89ww new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmplv2293l_ b/progress/SpecForge/cache/processed_dataset/tmplv2293l_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmplyok_act b/progress/SpecForge/cache/processed_dataset/tmplyok_act new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmplz5uvcdw b/progress/SpecForge/cache/processed_dataset/tmplz5uvcdw new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmplz6okuex b/progress/SpecForge/cache/processed_dataset/tmplz6okuex new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpm03nt4d3 b/progress/SpecForge/cache/processed_dataset/tmpm03nt4d3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpm83pokwe b/progress/SpecForge/cache/processed_dataset/tmpm83pokwe new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmb0ui3yj b/progress/SpecForge/cache/processed_dataset/tmpmb0ui3yj new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmc0tpypu b/progress/SpecForge/cache/processed_dataset/tmpmc0tpypu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmfekbe9i b/progress/SpecForge/cache/processed_dataset/tmpmfekbe9i new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmik3cctb b/progress/SpecForge/cache/processed_dataset/tmpmik3cctb new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmq9d4mx2 b/progress/SpecForge/cache/processed_dataset/tmpmq9d4mx2 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmsqw1vlf b/progress/SpecForge/cache/processed_dataset/tmpmsqw1vlf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpmwzmst4p b/progress/SpecForge/cache/processed_dataset/tmpmwzmst4p new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpn1gyzso1 b/progress/SpecForge/cache/processed_dataset/tmpn1gyzso1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpn38dz5h6 b/progress/SpecForge/cache/processed_dataset/tmpn38dz5h6 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpn8pbxoe1 b/progress/SpecForge/cache/processed_dataset/tmpn8pbxoe1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpn_x27z85 b/progress/SpecForge/cache/processed_dataset/tmpn_x27z85 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpnf_3lzbr b/progress/SpecForge/cache/processed_dataset/tmpnf_3lzbr new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpnj3ei2by b/progress/SpecForge/cache/processed_dataset/tmpnj3ei2by new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpnouxc505 b/progress/SpecForge/cache/processed_dataset/tmpnouxc505 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpnuorsutf b/progress/SpecForge/cache/processed_dataset/tmpnuorsutf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpnwtbn427 b/progress/SpecForge/cache/processed_dataset/tmpnwtbn427 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpo4ty5tm9 b/progress/SpecForge/cache/processed_dataset/tmpo4ty5tm9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpo5gsroyu b/progress/SpecForge/cache/processed_dataset/tmpo5gsroyu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpoitw1i6p b/progress/SpecForge/cache/processed_dataset/tmpoitw1i6p new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpokcnx25g b/progress/SpecForge/cache/processed_dataset/tmpokcnx25g new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpokffg52b b/progress/SpecForge/cache/processed_dataset/tmpokffg52b new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpoln75rt2 b/progress/SpecForge/cache/processed_dataset/tmpoln75rt2 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpomi6h50r b/progress/SpecForge/cache/processed_dataset/tmpomi6h50r new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpoqi9wpei b/progress/SpecForge/cache/processed_dataset/tmpoqi9wpei new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmppi8d7p_k b/progress/SpecForge/cache/processed_dataset/tmppi8d7p_k new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmppn730vrk b/progress/SpecForge/cache/processed_dataset/tmppn730vrk new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmppv7vu_ew b/progress/SpecForge/cache/processed_dataset/tmppv7vu_ew new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmppx68k5vk b/progress/SpecForge/cache/processed_dataset/tmppx68k5vk new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpq737g7k6 b/progress/SpecForge/cache/processed_dataset/tmpq737g7k6 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpq9dptg05 b/progress/SpecForge/cache/processed_dataset/tmpq9dptg05 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpqe0qw440 b/progress/SpecForge/cache/processed_dataset/tmpqe0qw440 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpqh2n14a9 b/progress/SpecForge/cache/processed_dataset/tmpqh2n14a9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpqi7wf2l9 b/progress/SpecForge/cache/processed_dataset/tmpqi7wf2l9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpqspjox3k b/progress/SpecForge/cache/processed_dataset/tmpqspjox3k new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpquw_hnb2 b/progress/SpecForge/cache/processed_dataset/tmpquw_hnb2 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpr4l50fzf b/progress/SpecForge/cache/processed_dataset/tmpr4l50fzf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmprhzf1xo5 b/progress/SpecForge/cache/processed_dataset/tmprhzf1xo5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmprievwpnd b/progress/SpecForge/cache/processed_dataset/tmprievwpnd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmprsa5u2ga b/progress/SpecForge/cache/processed_dataset/tmprsa5u2ga new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmprvzyu20t b/progress/SpecForge/cache/processed_dataset/tmprvzyu20t new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpryndvauh b/progress/SpecForge/cache/processed_dataset/tmpryndvauh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmps1leefjx b/progress/SpecForge/cache/processed_dataset/tmps1leefjx new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmps5wol24t b/progress/SpecForge/cache/processed_dataset/tmps5wol24t new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmps9lugd96 b/progress/SpecForge/cache/processed_dataset/tmps9lugd96 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpsf7enrm7 b/progress/SpecForge/cache/processed_dataset/tmpsf7enrm7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpslwbrv5h b/progress/SpecForge/cache/processed_dataset/tmpslwbrv5h new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpsu1a2jzy b/progress/SpecForge/cache/processed_dataset/tmpsu1a2jzy new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpsxbmecff b/progress/SpecForge/cache/processed_dataset/tmpsxbmecff new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpt4_p1_xh b/progress/SpecForge/cache/processed_dataset/tmpt4_p1_xh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpt64sb98s b/progress/SpecForge/cache/processed_dataset/tmpt64sb98s new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpt7yygvy7 b/progress/SpecForge/cache/processed_dataset/tmpt7yygvy7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmptbo785au b/progress/SpecForge/cache/processed_dataset/tmptbo785au new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmptcqumhvq b/progress/SpecForge/cache/processed_dataset/tmptcqumhvq new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmptl5k1drf b/progress/SpecForge/cache/processed_dataset/tmptl5k1drf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpts1ct7es b/progress/SpecForge/cache/processed_dataset/tmpts1ct7es new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmptsxr9ydj b/progress/SpecForge/cache/processed_dataset/tmptsxr9ydj new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpu4_vhdkz b/progress/SpecForge/cache/processed_dataset/tmpu4_vhdkz new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpu77a1axg b/progress/SpecForge/cache/processed_dataset/tmpu77a1axg new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpual4pb62 b/progress/SpecForge/cache/processed_dataset/tmpual4pb62 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpub_mrir9 b/progress/SpecForge/cache/processed_dataset/tmpub_mrir9 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuf6xv9sv b/progress/SpecForge/cache/processed_dataset/tmpuf6xv9sv new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpugwfyz0_ b/progress/SpecForge/cache/processed_dataset/tmpugwfyz0_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuh_q_ypd b/progress/SpecForge/cache/processed_dataset/tmpuh_q_ypd new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpum131v7j b/progress/SpecForge/cache/processed_dataset/tmpum131v7j new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpusw1tuwh b/progress/SpecForge/cache/processed_dataset/tmpusw1tuwh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuu0i6jhu b/progress/SpecForge/cache/processed_dataset/tmpuu0i6jhu new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuvyvkj04 b/progress/SpecForge/cache/processed_dataset/tmpuvyvkj04 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuw5ry1pw b/progress/SpecForge/cache/processed_dataset/tmpuw5ry1pw new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuylzb3g0 b/progress/SpecForge/cache/processed_dataset/tmpuylzb3g0 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpuz23cuds b/progress/SpecForge/cache/processed_dataset/tmpuz23cuds new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv0wlj8v_ b/progress/SpecForge/cache/processed_dataset/tmpv0wlj8v_ new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv3k6qnda b/progress/SpecForge/cache/processed_dataset/tmpv3k6qnda new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv3spj_6m b/progress/SpecForge/cache/processed_dataset/tmpv3spj_6m new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv4e32tee b/progress/SpecForge/cache/processed_dataset/tmpv4e32tee new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv61t1_j3 b/progress/SpecForge/cache/processed_dataset/tmpv61t1_j3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv6nvc5ur b/progress/SpecForge/cache/processed_dataset/tmpv6nvc5ur new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpv9phbzv7 b/progress/SpecForge/cache/processed_dataset/tmpv9phbzv7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvakn62ja b/progress/SpecForge/cache/processed_dataset/tmpvakn62ja new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvf9gzysk b/progress/SpecForge/cache/processed_dataset/tmpvf9gzysk new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvhisnqjc b/progress/SpecForge/cache/processed_dataset/tmpvhisnqjc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvigu1ncb b/progress/SpecForge/cache/processed_dataset/tmpvigu1ncb new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvlbxnmtf b/progress/SpecForge/cache/processed_dataset/tmpvlbxnmtf new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvxi2enj6 b/progress/SpecForge/cache/processed_dataset/tmpvxi2enj6 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpvxm5qyvm b/progress/SpecForge/cache/processed_dataset/tmpvxm5qyvm new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpw0wknpko b/progress/SpecForge/cache/processed_dataset/tmpw0wknpko new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpwx2isiul b/progress/SpecForge/cache/processed_dataset/tmpwx2isiul new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpwy96ckvp b/progress/SpecForge/cache/processed_dataset/tmpwy96ckvp new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxhbk2k8e b/progress/SpecForge/cache/processed_dataset/tmpxhbk2k8e new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxj68akt3 b/progress/SpecForge/cache/processed_dataset/tmpxj68akt3 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxn6xafp7 b/progress/SpecForge/cache/processed_dataset/tmpxn6xafp7 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxp3nf7ne b/progress/SpecForge/cache/processed_dataset/tmpxp3nf7ne new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxq6epe44 b/progress/SpecForge/cache/processed_dataset/tmpxq6epe44 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxvjjo98h b/progress/SpecForge/cache/processed_dataset/tmpxvjjo98h new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpxwh8vfqt b/progress/SpecForge/cache/processed_dataset/tmpxwh8vfqt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpy3oifg0b b/progress/SpecForge/cache/processed_dataset/tmpy3oifg0b new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpy9l3only b/progress/SpecForge/cache/processed_dataset/tmpy9l3only new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpyhb3ktyv b/progress/SpecForge/cache/processed_dataset/tmpyhb3ktyv new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpylvpa_i5 b/progress/SpecForge/cache/processed_dataset/tmpylvpa_i5 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpyqdoinnn b/progress/SpecForge/cache/processed_dataset/tmpyqdoinnn new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpyso7jmjt b/progress/SpecForge/cache/processed_dataset/tmpyso7jmjt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpz2nxtkoc b/progress/SpecForge/cache/processed_dataset/tmpz2nxtkoc new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpzfsolaz4 b/progress/SpecForge/cache/processed_dataset/tmpzfsolaz4 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpzj9bf_vv b/progress/SpecForge/cache/processed_dataset/tmpzj9bf_vv new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/cache/processed_dataset/tmpzpwd013g b/progress/SpecForge/cache/processed_dataset/tmpzpwd013g new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/configs/apollo.json b/progress/SpecForge/configs/apollo.json new file mode 100644 index 0000000000000000000000000000000000000000..dfd42202bcbc88f81d7ebb65e25d46b1c4456be5 --- /dev/null +++ b/progress/SpecForge/configs/apollo.json @@ -0,0 +1,8 @@ +{ + "rank": 256, + "proj": "random", + "scale_type": "channel", + "scale": 1, + "update_proj_gap": 200, + "proj_type": "std" +} diff --git a/progress/SpecForge/configs/apollo_mini.json b/progress/SpecForge/configs/apollo_mini.json new file mode 100644 index 0000000000000000000000000000000000000000..d10f1fd222cf7176577dbbaf87b54a3799aa2a5d --- /dev/null +++ b/progress/SpecForge/configs/apollo_mini.json @@ -0,0 +1,8 @@ +{ + "rank": 1, + "proj": "random", + "scale_type": "tensor", + "scale": 128, + "update_proj_gap": 200, + "proj_type": "std" +} diff --git a/progress/SpecForge/configs/deepseek-v2-lite-eagle3.json b/progress/SpecForge/configs/deepseek-v2-lite-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..da12c0fb4444a55773ac0f84f4360f3476a39d09 --- /dev/null +++ b/progress/SpecForge/configs/deepseek-v2-lite-eagle3.json @@ -0,0 +1,39 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 100000, + "eos_token_id": 100001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 10944, + "max_position_embeddings": 163840, + "max_window_layers": 64, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32.0, + "beta_slow": 1.0, + "factor": 40.0, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "rope_type": "yarn" + }, + "rope_theta": 10000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 102400, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/deepseek-v3-671b-eagle3.json b/progress/SpecForge/configs/deepseek-v3-671b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..147a5fdcd32c7ccd83248eec16dc709ed34e8bce --- /dev/null +++ b/progress/SpecForge/configs/deepseek-v3-671b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 29, + 57 + ], + "use_aux_hidden_state": true + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 40960, + "max_position_embeddings": 163840, + "model_type": "llama", + "num_attention_heads": 56, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.51.0", + "use_cache": true, + "vocab_size": 129280, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/gemma3-1b-eagle3.json b/progress/SpecForge/configs/gemma3-1b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..e5e74eb16a3e47ac9ff4357106ff7c2afe4186da --- /dev/null +++ b/progress/SpecForge/configs/gemma3-1b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 1152, + "initializer_range": 0.02, + "intermediate_size": 6912, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 4, + "num_hidden_layers": 1, + "num_key_value_heads": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262145, + "draft_vocab_size": 32000, + "target_model_type": "gemma3_text" +} diff --git a/progress/SpecForge/configs/gpt-oss-120B-eagle3.json b/progress/SpecForge/configs/gpt-oss-120B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f4b36c7687620c95e90b4ec43ee8a53763826954 --- /dev/null +++ b/progress/SpecForge/configs/gpt-oss-120B-eagle3.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 17, + 33 + ] + }, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2880, + "initializer_range": 0.02, + "intermediate_size": 17280, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 201088, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/gpt-oss-20B-eagle3.json b/progress/SpecForge/configs/gpt-oss-20B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..e1d4b257d9644032488a31a67aca8719ffdbe33e --- /dev/null +++ b/progress/SpecForge/configs/gpt-oss-20B-eagle3.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 11, + 21 + ] + }, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 2880, + "initializer_range": 0.02, + "intermediate_size": 17280, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 201088, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/ling-flash-2.0-eagle3.json b/progress/SpecForge/configs/ling-flash-2.0-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..0a9bea37c06ae29010eade7cd4b70cdf4e9e0316 --- /dev/null +++ b/progress/SpecForge/configs/ling-flash-2.0-eagle3.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 163584, + "eos_token_id": 163585, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.1", + "use_cache": true, + "vocab_size": 157184, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/llama3-70B-ealge3.json b/progress/SpecForge/configs/llama3-70B-ealge3.json new file mode 100644 index 0000000000000000000000000000000000000000..20d04f4d0dc09fe2894a7a35673b3a8afdaa8e32 --- /dev/null +++ b/progress/SpecForge/configs/llama3-70B-ealge3.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 128000, + "eos_token_id": [ + 128001, + 128008, + 128009 + ], + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 4096, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": true, + "vocab_size": 128256, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/llama3-8B-eagle3.json b/progress/SpecForge/configs/llama3-8B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..775ad6afee3c43946742b823b8f4e3d48af68b3c --- /dev/null +++ b/progress/SpecForge/configs/llama3-8B-eagle3.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": true, + "vocab_size": 128256, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/llama4-scout-17B-16E-eagle3.json b/progress/SpecForge/configs/llama4-scout-17B-16E-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..9c2bb5a81a3b5452836b0c6dcf1ba29e4ecc64e5 --- /dev/null +++ b/progress/SpecForge/configs/llama4-scout-17B-16E-eagle3.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.52.3", + "use_cache": true, + "vocab_size": 202048, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/longcat-flash-dflash.json b/progress/SpecForge/configs/longcat-flash-dflash.json new file mode 100644 index 0000000000000000000000000000000000000000..66e9b33a614a15dc3c5df35d9f6cb8aabe818d61 --- /dev/null +++ b/progress/SpecForge/configs/longcat-flash-dflash.json @@ -0,0 +1,45 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "dflash.DFlashDraftModel" + }, + "block_size": 16, + "bos_token_id": 1, + "dflash_config": { + "mask_token_id": 2, + "target_layer_ids": [1, 7, 13, 19, 25] + }, + "dtype": "bfloat16", + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 131072 + } diff --git a/progress/SpecForge/configs/longcat-flash-eagle3.json b/progress/SpecForge/configs/longcat-flash-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..7b3b921a22378353f010d1ee1ba03ec44610eb75 --- /dev/null +++ b/progress/SpecForge/configs/longcat-flash-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 131072, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads":16, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 10000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 131072, + "draft_vocab_size": 131072 + } diff --git a/progress/SpecForge/configs/phi4-eagle3.json b/progress/SpecForge/configs/phi4-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..05456a0d239653cdc898413860c6822d8a7cdec5 --- /dev/null +++ b/progress/SpecForge/configs/phi4-eagle3.json @@ -0,0 +1,27 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 100257, + "eos_token_id": 100257, + "pad_token_id": 100257, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 16384, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 1, + "num_key_value_heads": 10, + "rms_norm_eps": 1e-05, + "rope_theta": 250000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.47.0", + "use_cache": true, + "vocab_size": 100352, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen2-5-vl-7b-eagle3.json b/progress/SpecForge/configs/qwen2-5-vl-7b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..672193e3b1284badcb747356f1cbfcd402e19ccf --- /dev/null +++ b/progress/SpecForge/configs/qwen2-5-vl-7b-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/progress/SpecForge/configs/qwen2.5-7b-eagle3.json b/progress/SpecForge/configs/qwen2.5-7b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f16f6b8d07b120734f1eafd8c2e7881e424a57a1 --- /dev/null +++ b/progress/SpecForge/configs/qwen2.5-7b-eagle3.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "llama", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": 131072, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 16000 +} diff --git a/progress/SpecForge/configs/qwen2.5-vl-32b-eagle3.json b/progress/SpecForge/configs/qwen2.5-vl-32b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..76aa04cdf7cdf706443308f72f5e487cf6f510ff --- /dev/null +++ b/progress/SpecForge/configs/qwen2.5-vl-32b-eagle3.json @@ -0,0 +1,40 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 8192, + "max_window_layers": 28, + "model_type": "llama", + "target_model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "pretraining_tp": 1, + "rope_scaling": { + "type": "mrope", + "mrope_section": [ + 16, + 24, + 24 + ] + }, + "rope_theta": 1000000, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 + } diff --git a/progress/SpecForge/configs/qwen3-235B-A22B-eagle3.json b/progress/SpecForge/configs/qwen3-235B-A22B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..8e28c04a18a851c968252b1691b89dcdcff598b9 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-235B-A22B-eagle3.json @@ -0,0 +1,36 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 46, + 90 + ], + "use_aux_hidden_state": true + }, + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "draft_vocab_size": 32000, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 24576, + "max_position_embeddings": 40960, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "rope_scaling": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "vocab_size": 151936 +} diff --git a/progress/SpecForge/configs/qwen3-30B-A3B-eagle3.json b/progress/SpecForge/configs/qwen3-30B-A3B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..558cb18043a5bd182497536203de90a4a7672f35 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-30B-A3B-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 2048, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads":4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-32b-dflash.json b/progress/SpecForge/configs/qwen3-32b-dflash.json new file mode 100644 index 0000000000000000000000000000000000000000..7e35c0f4293cd42deff4009cabda1bd789bbe0bc --- /dev/null +++ b/progress/SpecForge/configs/qwen3-32b-dflash.json @@ -0,0 +1,68 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + + "auto_map": { + "AutoModel": "dflash.DFlashDraftModel" + }, + + "block_size": 16, + + "bos_token_id": 151643, + "eos_token_id": 151645, + + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [1, 8, 14, 21, 28, 34, 41, 48, 54, 61] + }, + + "dtype": "bfloat16", + + "head_dim": 128, + + "hidden_act": "silu", + "hidden_size": 5120, + "intermediate_size": 25600, + + "initializer_range": 0.02, + + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + + "max_position_embeddings": 40960, + "max_window_layers": 10, + + "model_type": "qwen3", + + "num_attention_heads": 64, + "num_key_value_heads": 8, + + "num_hidden_layers": 10, + "num_target_layers": 64, + + "rms_norm_eps": 1e-06, + + "rope_theta": 1000000, + "rope_scaling": null, + + "sliding_window": null, + "use_sliding_window": false, + + "tie_word_embeddings": false, + "use_cache": true, + + "vocab_size": 151936 +} \ No newline at end of file diff --git a/progress/SpecForge/configs/qwen3-32b-eagle3.json b/progress/SpecForge/configs/qwen3-32b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..cf128d9fb451833207c0a4293554357f324aea8c --- /dev/null +++ b/progress/SpecForge/configs/qwen3-32b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 25600, + "max_position_embeddings": 40960, + "max_window_layers": 64, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-4b-eagle3.json b/progress/SpecForge/configs/qwen3-4b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..41ae128fdcd532f1e31c6251819d29aedfa9d3e6 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-4b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2560, + "initializer_range": 0.02, + "intermediate_size": 9728, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-8b-dflash-lora.json b/progress/SpecForge/configs/qwen3-8b-dflash-lora.json new file mode 100644 index 0000000000000000000000000000000000000000..b41f701a12271eb45291347d6d97fbfaca7c5a95 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-8b-dflash-lora.json @@ -0,0 +1,10 @@ +{ + "lora_rank": 16, + "lora_alpha": 32, + "lora_dropout": 0.05, + "lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], + "block_size": 16, + "mask_token_id": 151669, + "model_type": "qwen3", + "base_model": "Qwen/Qwen3-8B" +} diff --git a/progress/SpecForge/configs/qwen3-8b-dflash.json b/progress/SpecForge/configs/qwen3-8b-dflash.json new file mode 100644 index 0000000000000000000000000000000000000000..518860725a65bae6674c0af60643394ef174f2d9 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-8b-dflash.json @@ -0,0 +1,45 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "dflash.DFlashDraftModel" + }, + "block_size": 16, + "bos_token_id": 151643, + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [1, 9, 17, 25, 33] + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/progress/SpecForge/configs/qwen3-8b-eagle3.json b/progress/SpecForge/configs/qwen3-8b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..b1fa44906d6decad8ccee5c8296699b1db5750f1 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-8b-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads":8 , + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-coder-30B-A3B-instruct-eagle3.json b/progress/SpecForge/configs/qwen3-coder-30B-A3B-instruct-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..f296c237973a83f40f4540a97bbc193e2593bb44 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-coder-30B-A3B-instruct-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 2048, + "max_window_layers": 48, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-coder-480B-A35B-instruct-eagle3.json b/progress/SpecForge/configs/qwen3-coder-480B-A35B-instruct-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..2f27c80cc017e811f8846f2161a977725e669086 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-coder-480B-A35B-instruct-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 262144, + "max_window_layers": 62, + "model_type": "llama", + "num_attention_heads": 96, + "num_hidden_layers": 1, + "num_key_value_heads":8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/configs/qwen3-next-80b-a3b-eagle3.json b/progress/SpecForge/configs/qwen3-next-80b-a3b-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..e94a2ea3407d784ee9fbd4b6a15b96cd7cadfec8 --- /dev/null +++ b/progress/SpecForge/configs/qwen3-next-80b-a3b-eagle3.json @@ -0,0 +1,29 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 262144, + "model_type": "llama", + "num_attention_heads": 16, + "num_hidden_layers": 1, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000000, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.0.dev0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 + } diff --git a/progress/SpecForge/configs/qwq-32B-eagle3.json b/progress/SpecForge/configs/qwq-32B-eagle3.json new file mode 100644 index 0000000000000000000000000000000000000000..8f7d7908d5433c886a1725c1ec456f032ba80202 --- /dev/null +++ b/progress/SpecForge/configs/qwq-32B-eagle3.json @@ -0,0 +1,28 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 27648, + "max_position_embeddings": 40960, + "max_window_layers": 64, + "model_type": "qwen2", + "num_attention_heads": 40, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.1", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 152064, + "draft_vocab_size": 32000 +} diff --git a/progress/SpecForge/datasets/README.md b/progress/SpecForge/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8ddbef6d72d759dc06d8e59c15ef73c0ec29c204 --- /dev/null +++ b/progress/SpecForge/datasets/README.md @@ -0,0 +1,5 @@ +## Store Comprehensive Datasets Download Scripts + +| DatasetName | Github | Huggingface | command | +| -------- | -------- | -------- | -------- | +| ALLaVA-4V | [link](https://github.com/FreedomIntelligence/ALLaVA) | [link](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V) | download_laion.sh | diff --git a/progress/SpecForge/datasets/download_laion.sh b/progress/SpecForge/datasets/download_laion.sh new file mode 100644 index 0000000000000000000000000000000000000000..a64d061ebb5de06b2e87cfc3bcd2b38508b7009e --- /dev/null +++ b/progress/SpecForge/datasets/download_laion.sh @@ -0,0 +1,36 @@ + + +laion_root="allava_laion" + +mkdir $laion_root +cd $laion_root + + +# 1. download annotation files +## 1.1 caption +wget -c -O ALLaVA-Caption-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Caption-LAION-4V.json?download=true + +## 1.2 instruction +wget -c -O ALLaVA-Instruct-LAION-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Instruct-LAION-4V.json?download=true + + +# 2. download and upzip images +mkdir image_chunks + +## 2.1 download +for ((i=0; i<10; i++)) +do + wget -c -O image_chunks/images_$i.zip https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/image_chunks/images_$i.zip?download=true & +done + +mkdir -p images/ +wait + +## 2.2 unzip +for ((i=0; i<10; i++)) +do + unzip -j -o image_chunks/images_$i.zip -d images/ & # wait patiently, it takes a while... +done + +wait +echo "All done!" diff --git a/progress/SpecForge/docs/Makefile b/progress/SpecForge/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..6b8792c428564ace773add1f751f7c2471a8fe83 --- /dev/null +++ b/progress/SpecForge/docs/Makefile @@ -0,0 +1,58 @@ +# Minimal Makefile for Sphinx documentation +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SPHINXAUTOBUILD ?= sphinx-autobuild +SOURCEDIR = . +BUILDDIR = _build +PORT ?= 8003 + +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @echo "" + @echo "Additional targets:" + @echo " serve to build and serve documentation with auto-build and live reload" + +# Compile Notebook files and record execution time +compile: + @set -e; \ + echo "Starting Notebook compilation..."; \ + mkdir -p logs; \ + echo "Notebook execution timings:" > logs/timing.log; \ + START_TOTAL=$$(date +%s); \ + find $(SOURCEDIR) -path "*/_build/*" -prune -o -name "*.ipynb" -print0 | \ + parallel -0 -j3 --halt soon,fail=1 ' \ + NB_NAME=$$(basename {}); \ + START_TIME=$$(date +%s); \ + retry --delay=0 --times=2 -- \ + jupyter nbconvert --to notebook --execute --inplace "{}" \ + --ExecutePreprocessor.timeout=600 \ + --ExecutePreprocessor.kernel_name=python3; \ + RET_CODE=$$?; \ + END_TIME=$$(date +%s); \ + ELAPSED_TIME=$$((END_TIME - START_TIME)); \ + echo "$${NB_NAME}: $${ELAPSED_TIME}s" >> logs/timing.log; \ + exit $$RET_CODE' || exit 1; \ + END_TOTAL=$$(date +%s); \ + TOTAL_ELAPSED=$$((END_TOTAL - START_TOTAL)); \ + echo "---------------------------------" >> logs/timing.log; \ + echo "Total execution time: $${TOTAL_ELAPSED}s" >> logs/timing.log; \ + echo "All Notebook execution timings:" && cat logs/timing.log + +# Serve documentation with auto-build and live reload +serve: + @echo "Starting auto-build server at http://0.0.0.0:$(PORT)" + @$(SPHINXAUTOBUILD) "$(SOURCEDIR)" "$(BUILDDIR)/html" \ + --host 0.0.0.0 \ + --port $(PORT) \ + --watch $(SOURCEDIR) \ + --re-ignore ".*\.(ipynb_checkpoints|pyc|pyo|pyd|git)" + +.PHONY: help Makefile compile clean serve + +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + find . -name "*.ipynb" -exec nbstripout {} \; + rm -rf $(BUILDDIR) + rm -rf logs diff --git a/progress/SpecForge/docs/README.md b/progress/SpecForge/docs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..592f0e51a0f9be1b4aa959867fb526ed4003c149 --- /dev/null +++ b/progress/SpecForge/docs/README.md @@ -0,0 +1,55 @@ +# SpecForge Documentation + +We recommend new contributors to start from writing documentation, which helps you quickly understand the SpecForge codebase. +Most documentation files are located under the `docs/` folder. + +## Docs Workflow + +### Install Dependency + +```bash +apt-get update && apt-get install -y pandoc parallel retry +pip install -r requirements.txt +``` + +### Update Documentation + +Update your Jupyter notebooks in the appropriate subdirectories under `docs/`. If you add new files, remember to update `index.rst` (or relevant `.rst` files) accordingly. + +- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request. + +```bash +# 1) Compile all Jupyter notebooks +make compile # This step can take a long time (10+ mins). You can consider skipping this step if you can make sure your added files are correct. +make html + +# 2) Compile and Preview documentation locally with auto-build +# This will automatically rebuild docs when files change +# Open your browser at the displayed port to view the docs +bash serve.sh + +# 2a) Alternative ways to serve documentation +# Directly use make serve +make serve +# With custom port +PORT=8080 make serve + +# 3) Clean notebook outputs +# nbstripout removes notebook outputs so your PR stays clean +pip install nbstripout +find . -name '*.ipynb' -exec nbstripout {} \; + +# 4) Pre-commit checks and create a PR +# After these checks pass, push your changes and open a PR on your branch +pre-commit run --all-files +``` +--- + +## Documentation Style Guidelines + +- For common functionalities, we prefer **Jupyter Notebooks** over Markdown so that all examples can be executed and validated by our docs CI pipeline. For complex features (e.g., distributed serving), Markdown is preferred. +- Keep in mind the documentation execution time when writing interactive Jupyter notebooks. Each interactive notebook will be run and compiled against every commit to ensure they are runnable, so it is important to apply some tips to reduce the documentation compilation time: + - Use small models (e.g., `qwen/qwen2.5-0.5b-instruct`) for most cases to reduce server launch time. + - Reuse the launched server as much as possible to reduce server launch time. +- Do not use absolute links (e.g., `https://docs.sglang.ai/get_started/install.html`). Always prefer relative links (e.g., `../get_started/install.md`). +- Follow the existing examples to learn how to launch a server, send a query and other common styles. diff --git a/progress/SpecForge/docs/conf.py b/progress/SpecForge/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fef2396e931693259e82aee2e78cdb77d6c256 --- /dev/null +++ b/progress/SpecForge/docs/conf.py @@ -0,0 +1,188 @@ +import os +import sys +from datetime import datetime +from pathlib import Path + +sys.path.insert(0, os.path.abspath("../..")) + +DOCS_PATH = Path(__file__).parent +ROOT_PATH = DOCS_PATH.parent + +version_file = ROOT_PATH.joinpath("version.txt") +with open(version_file, "r") as f: + __version__ = f.read().strip() + +project = "SGLang" +copyright = f"2025-{datetime.now().year}, SpecForge" +author = "SpecForge Team" + +version = __version__ +release = __version__ + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", + "sphinx.ext.intersphinx", + "sphinx_tabs.tabs", + "myst_parser", + "sphinx_copybutton", + "sphinxcontrib.mermaid", + "nbsphinx", + "sphinx.ext.mathjax", +] + +nbsphinx_allow_errors = True +nbsphinx_execute = "never" + +autosectionlabel_prefix_document = True +nbsphinx_allow_directives = True + + +myst_enable_extensions = [ + "dollarmath", + "amsmath", + "deflist", + "colon_fence", + "html_image", + "substitution", +] + +myst_heading_anchors = 5 + +nbsphinx_kernel_name = "python3" +nbsphinx_execute_arguments = [ + "--InlineBackend.figure_formats={'svg', 'pdf'}", + "--InlineBackend.rc={'figure.dpi': 96}", +] + + +nb_render_priority = { + "html": ( + "application/vnd.jupyter.widget-view+json", + "application/javascript", + "text/html", + "image/svg+xml", + "image/png", + "image/jpeg", + "text/markdown", + "text/latex", + "text/plain", + ) +} + +myst_ref_domains = ["std", "py"] + +templates_path = ["_templates"] + +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", +} + +master_doc = "index" + +language = "en" + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +pygments_style = "sphinx" + +html_theme = "sphinx_book_theme" +html_logo = ROOT_PATH.joinpath("assets/logo.png").as_posix() +html_favicon = ROOT_PATH.joinpath("assets/logo.ico").as_posix() +html_title = project +html_copy_source = True +html_last_updated_fmt = "" + +html_theme_options = { + "repository_url": "https://github.com/sgl-project/sgl-project.github.io", + "repository_branch": "main", + "show_navbar_depth": 3, + "max_navbar_depth": 4, + "collapse_navbar": True, + "use_edit_page_button": True, + "use_source_button": True, + "use_issues_button": True, + "use_repository_button": True, + "use_download_button": True, + "use_sidenotes": True, + "show_toc_level": 2, +} + +html_context = { + "display_github": True, + "github_user": "sgl-project", + "github_repo": "sgl-project.github.io", + "github_version": "main", + "conf_py_path": "/docs/", +} + +html_static_path = ["_static", "spec_bundle/public"] +html_css_files = ["css/custom_log.css"] + + +def setup(app): + app.add_css_file("css/custom_log.css") + + +htmlhelp_basename = "sglangdoc" + +latex_elements = {} + +latex_documents = [ + (master_doc, "sglang.tex", "sglang Documentation", "SGLang Team", "manual"), +] + +man_pages = [(master_doc, "sglang", "sglang Documentation", [author], 1)] + +texinfo_documents = [ + ( + master_doc, + "sglang", + "sglang Documentation", + author, + "sglang", + "One line description of project.", + "Miscellaneous", + ), +] + +epub_title = project + +epub_exclude_files = ["search.html"] + +copybutton_prompt_text = r">>> |\.\.\. " +copybutton_prompt_is_regexp = True + +autodoc_preserve_defaults = True +navigation_with_keys = False + +autodoc_mock_imports = [ + "torch", + "transformers", + "triton", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.12", None), + "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest", None), + "pillow": ("https://pillow.readthedocs.io/en/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), +} + +html_theme = "sphinx_book_theme" + + +nbsphinx_prolog = """ +.. raw:: html + + +""" diff --git a/progress/SpecForge/docs/deploy.py b/progress/SpecForge/docs/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..75b7ea7f23dce0a5deb17c28d78b5cc59833a4d6 --- /dev/null +++ b/progress/SpecForge/docs/deploy.py @@ -0,0 +1,22 @@ +# Deploy the documents + +import os +from datetime import datetime + + +def run_cmd(cmd): + print(cmd) + os.system(cmd) + + +run_cmd("cd $DOC_SITE_PATH; git pull") + +# (Optional) Remove old files +# run_cmd("rm -rf $ALPA_SITE_PATH/*") + +run_cmd("cp -r _build/html/* $DOC_SITE_PATH") + +cmd_message = f"Update {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" +run_cmd( + f"cd $DOC_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin main" +) diff --git a/progress/SpecForge/docs/index.rst b/progress/SpecForge/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..bc2c694798793eddd894f5bd94fde539b9fb06b8 --- /dev/null +++ b/progress/SpecForge/docs/index.rst @@ -0,0 +1,53 @@ +SpecForge Documentation +======================= + +SpecForge is an ecosystem project developed by the SGLang team. It is a framework for training speculative decoding models so that you can smoothly port them over to the SGLang serving framework to speed up your inference. + + +.. toctree:: + :maxdepth: 1 + :caption: Get Started + + get_started/installation.md + get_started/about.md + +.. toctree:: + :maxdepth: 1 + :caption: Concepts + + concepts/speculative_decoding.md + concepts/EAGLE3.md + + +.. toctree:: + :maxdepth: 1 + :caption: Basic Usage + + basic_usage/data_preparation.md + basic_usage/training.md + +.. toctree:: + :maxdepth: 1 + :caption: Advanced Features + + advanced_features/customization.md + +.. toctree:: + :maxdepth: 1 + :caption: Community Resources + + community_resources/specbundle.md + community_resources/dashboard.md + +.. toctree:: + :maxdepth: 1 + :caption: Examples + + examples/llama3-eagle3-online.md + examples/llama3-eagle3-offline.md + +.. toctree:: + :maxdepth: 1 + :caption: Benchmarks + + benchmarks/benchmark.md diff --git a/progress/SpecForge/docs/requirements.txt b/progress/SpecForge/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1a7e5d4eba2f265cb2dce4eff31d770eb71125f3 --- /dev/null +++ b/progress/SpecForge/docs/requirements.txt @@ -0,0 +1,20 @@ +ipykernel +ipywidgets +jupyter_client +markdown>=3.4.0 +matplotlib +myst-parser +nbconvert +nbsphinx +pandoc +pillow +pydantic +sphinx +sphinx-book-theme +sphinx-copybutton +sphinx-tabs +nbstripout +sphinxcontrib-mermaid +urllib3<2.0.0 +gguf>=0.10.0 +sphinx-autobuild diff --git a/progress/SpecForge/docs/serve.sh b/progress/SpecForge/docs/serve.sh new file mode 100644 index 0000000000000000000000000000000000000000..049f767cf497a5fd92b1dac0af2fc13fdcf3fa69 --- /dev/null +++ b/progress/SpecForge/docs/serve.sh @@ -0,0 +1,3 @@ +# Clean and serve documentation with auto-build +make clean +make serve diff --git a/progress/SpecForge/examples/README.md b/progress/SpecForge/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ff5a6f3a8a5aae6c9ff7645afc266dc6cd7363bc --- /dev/null +++ b/progress/SpecForge/examples/README.md @@ -0,0 +1,9 @@ +# Run SpecForge Examples + +This folder contains the examples of running SpecForge on different models. The scripts can be invoked by the following command: + +```bash +bash examples/.sh [NUM_GPUS] [TP_SIZE] +``` + +We use the ShareGPT dataset for all the examples for now, but you can replace it with more robust datasets such as perfectblend, magpie-qwen2.5-pro-1m-v0.1, etc. diff --git a/progress/SpecForge/examples/run_deepseek_v2_lite_eagle3_online.sh b/progress/SpecForge/examples/run_deepseek_v2_lite_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..283c62ce80743d52bc29aaddc3b0b7a9829890c7 --- /dev/null +++ b/progress/SpecForge/examples/run_deepseek_v2_lite_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for deepseek-v2-lite +NUM_GPUS=${1:-8} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V2-Lite \ + --draft-model-config $ROOT_DIR/configs/deepseek-v2-lite-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v2-lite-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template deepseek \ + --target-model-backend hf \ + --cache-dir $ROOT_DIR/cache diff --git a/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_offline.sh b/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..4bede1dd50be44503365fb03fce8624e1bed2d4e --- /dev/null +++ b/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_offline.sh @@ -0,0 +1,43 @@ + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for deepseek-v3 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/perfect-blend-deepseek-v3 \ + --chat-template deepseek-v3 \ + --max-length 2048 \ + --tp-size 8 \ + --batch-size 4 \ + --sglang-mem-fraction-static 0.75 + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --draft-model-config $ROOT_DIR/configs/deepseek-v3-671b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/perfect-blend-deepseek-v3 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v3-671B-eagle3-perfect-blend-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template deepseek-v3 \ + --cache-dir $ROOT_DIR/cache diff --git a/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_online.sh b/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..2eb2769f9b5582a811a83305b72ea67bef5b514b --- /dev/null +++ b/progress/SpecForge/examples/run_deepseek_v3_671b_eagle3_online.sh @@ -0,0 +1,29 @@ + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for deepseek-v3 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# train eagle3 online +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path deepseek-ai/DeepSeek-V3 \ + --draft-model-config $ROOT_DIR/configs/deepseek-v3-671b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/deepseek-v3-671B-eagle3-perfect-blend-online \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template deepseek-v3 \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 \ + --sglang-mem-fraction-static 0.75 diff --git a/progress/SpecForge/examples/run_gemma3_1b_eagle3_online.sh b/progress/SpecForge/examples/run_gemma3_1b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1365069594baae0f7b8acbc45640c5ea39e0731 --- /dev/null +++ b/progress/SpecForge/examples/run_gemma3_1b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma3-1b +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-3-1b-it \ + --draft-model-config $ROOT_DIR/configs/gemma3-1b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --output-dir $ROOT_DIR/outputs/gemma3-1b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gemma \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 10 diff --git a/progress/SpecForge/examples/run_gpt_oss_120b_eagle3_online.sh b/progress/SpecForge/examples/run_gpt_oss_120b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..eea5afbd8a6512945d35f4b005c338fc5f1671a8 --- /dev/null +++ b/progress/SpecForge/examples/run_gpt_oss_120b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for GPT-OSS-120B +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path openai/gpt-oss-120b \ + --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/gpt-oss-20b-eagle3 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gpt-oss \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 diff --git a/progress/SpecForge/examples/run_gpt_oss_20b_eagle3_online.sh b/progress/SpecForge/examples/run_gpt_oss_20b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..55baeac1c49576e25162b6ff78558a7df8c4ee2d --- /dev/null +++ b/progress/SpecForge/examples/run_gpt_oss_20b_eagle3_online.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for GPT-OSS-20B +NUM_GPUS=${1:-8} +TP_SIZE=${2:-2} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path openai/gpt-oss-20b \ + --draft-model-config $ROOT_DIR/configs/gpt-oss-20B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend-gptoss-20B.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/perfect-blend-gptoss-20b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template gpt-oss \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 diff --git a/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_offline.sh b/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7f2925b4bd3381d0d9984a59db7e3b0f3699faa --- /dev/null +++ b/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_offline.sh @@ -0,0 +1,45 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for ling-flash-2.0 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/perfect-blend-ling-flash-2.0 \ + --chat-template ling-flash-2.0 \ + --max-length 2048 \ + --tp-size $TP_SIZE \ + --batch-size 4 \ + --sglang-mem-fraction-static 0.75 \ + --trust-remote-code + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --draft-model-config $ROOT_DIR/configs/ling-flash-2.0-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/perfect-blend-ling-flash-2.0 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/ling-flash-2.0-eagle3-perfect-blend-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template ling-flash-2.0 \ + --embedding-key 'model.word_embeddings.weight' \ + --cache-dir $ROOT_DIR/cache \ + --trust-remote-code diff --git a/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_online.sh b/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f9d1cc3d87905107b0708b99ec8a32a832831a2 --- /dev/null +++ b/progress/SpecForge/examples/run_ling_flash_2.0_eagle3_online.sh @@ -0,0 +1,30 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for ling-flash-2.0 +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# train eagle3 online +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path inclusionAI/Ling-flash-2.0 \ + --draft-model-config $ROOT_DIR/configs/ling-flash-2.0-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfect-blend.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/ling-flash-2.0-eagle3-perfect-blend-online \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 5e-5 \ + --max-length 2048 \ + --chat-template ling-flash-2.0 \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 60 \ + --sglang-mem-fraction-static 0.75 \ + --embedding-key 'model.word_embeddings.weight' \ + --trust-remote-code diff --git a/progress/SpecForge/examples/run_llama3.1_8b_eagle3_offline.sh b/progress/SpecForge/examples/run_llama3.1_8b_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..dffcbef845b727e6be2eeb2f24d63de2cc8b693f --- /dev/null +++ b/progress/SpecForge/examples/run_llama3.1_8b_eagle3_offline.sh @@ -0,0 +1,39 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# generate hidden states +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --output-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --chat-template llama3 \ + --max-length 4096 \ + --tp-size $TP_SIZE \ + --batch-size 32 + +# train eagle3 offline +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache diff --git a/progress/SpecForge/examples/run_llama3.1_8b_eagle3_online.sh b/progress/SpecForge/examples/run_llama3.1_8b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..d47c1797fa14aaaba21784f108ea2c270163b805 --- /dev/null +++ b/progress/SpecForge/examples/run_llama3.1_8b_eagle3_online.sh @@ -0,0 +1,29 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend sglang \ + --log-interval 10 \ + --sglang-mem-fraction-static 0.25 diff --git a/progress/SpecForge/examples/run_llama3.3_70b_eagle3_online.sh b/progress/SpecForge/examples/run_llama3.3_70b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..0ea80413df676a319d8a1a38eaeb036d21d3321d --- /dev/null +++ b/progress/SpecForge/examples/run_llama3.3_70b_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.3-70B-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama3-70B-ealge3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama3.3-70b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir $ROOT_DIR/cache \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_llama4_scout_eagle3_online.sh b/progress/SpecForge/examples/run_llama4_scout_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..73ed03a617297b64569cafecfd5bce9a2cf8f940 --- /dev/null +++ b/progress/SpecForge/examples/run_llama4_scout_eagle3_online.sh @@ -0,0 +1,25 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for llama3.1-8b +NUM_GPUS=${1:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-4-Scout-17B-16E-Instruct \ + --draft-model-config $ROOT_DIR/configs/llama4-scout-17B-16E-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/llama4-scout-17B-16E-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama4 \ + --cache-dir $ROOT_DIR/cache \ + --tp-size 8 \ + --embedding-key language_model.model.embed_tokens.weight \ diff --git a/progress/SpecForge/examples/run_longcat_flash_dflash_online.sh b/progress/SpecForge/examples/run_longcat_flash_dflash_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..7e5b7ec72c7bea213251143d3bd6f6d0202f4a15 --- /dev/null +++ b/progress/SpecForge/examples/run_longcat_flash_dflash_online.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=${SPECFORGE_DATA_NUM_PROC:-64} + +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} +WANDB_MODE=offline +SGL_JIT_DEEPGEMM_PRECOMPILE=false +SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path meituan-longcat/LongCat-Flash-Chat-FP8 \ + --target-model-backend sglang \ + --tp-size $NUM_GPUS \ + --sglang-attention-backend flashinfer \ + --sglang-mem-fraction-static 0.75 \ + --sglang-ep-size $NUM_GPUS \ + --draft-config-path $ROOT_DIR/configs/longcat-flash-dflash.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/longcat-flash-dflash-sharegpt \ + --num-epochs 6 \ + --batch-size 2 \ + --learning-rate 6e-4 \ + --warmup-ratio 0.04 \ + --max-grad-norm 1.0 \ + --max-length 3072 \ + --chat-template longcat \ + --random-anchor \ + --num-anchors 512 \ + --loss-decay-gamma 7.0 \ + --log-interval 50 \ + --save-interval 1000 \ + --report-to wandb \ + --wandb-project specforge-longcat-flash-dflash \ + --wandb-name longcat-flash-dflash-sharegpt \ + --mask-token-id 2 diff --git a/progress/SpecForge/examples/run_longcat_flash_eagle3_online.sh b/progress/SpecForge/examples/run_longcat_flash_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..f89cb502009610f0f19db6fef25c268ff1c8f641 --- /dev/null +++ b/progress/SpecForge/examples/run_longcat_flash_eagle3_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path meituan-longcat/LongCat-Flash-Chat-FP8 \ + --draft-model-config $ROOT_DIR/configs/longcat-flash-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/longcat-flash-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template longcat \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend sglang \ + --log-interval 10 \ + --sglang-mem-fraction-static 0.75 \ + --sglang-attention-backend flashinfer \ + --sglang-ep-size $NUM_GPUS diff --git a/progress/SpecForge/examples/run_phi4_eagle3_online.sh b/progress/SpecForge/examples/run_phi4_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..f306d22e71941b295bab03695a7f3c3187fc54d5 --- /dev/null +++ b/progress/SpecForge/examples/run_phi4_eagle3_online.sh @@ -0,0 +1,27 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path microsoft/phi-4 \ + --draft-model-config $ROOT_DIR/configs/phi4-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/phi4-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template phi4 \ + --cache-dir $ROOT_DIR/cache \ + --target-model-backend sglang \ + --embedding-key model.embed_tokens.weight diff --git a/progress/SpecForge/examples/run_qwen2.5_32b_vl_eagle3_online.sh b/progress/SpecForge/examples/run_qwen2.5_32b_vl_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..a7c86b0e502e19a1e39f42860d4804c768e84642 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen2.5_32b_vl_eagle3_online.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen2.5-VL-32B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2.5-vl-32b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/allava4v_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen2.5-vl-32b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --target-model-backend sglang \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 4 \ + --sglang-mem-fraction-static 0.5 \ + --is-vlm \ + --min-pixels 200704 \ + --max-pixels 1003520 diff --git a/progress/SpecForge/examples/run_qwen2.5_7b_vl_eagle3_online.sh b/progress/SpecForge/examples/run_qwen2.5_7b_vl_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..e94e6e39882484f0e6590e72686d594ba4bf1ff0 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen2.5_7b_vl_eagle3_online.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# support tp1 train eagle3 for qwen2.5-vl-7b-instruct +NUM_GPUS=${1:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen2.5-VL-7B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen2-5-vl-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/allava4v_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen2.5-VL-7B-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 8192 \ + --dist-timeout 360 \ + --chat-template qwen2-vl \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 1 \ + --is-vlm \ + --min-pixels 50176 \ + --max-pixels 802816 diff --git a/progress/SpecForge/examples/run_qwen3_235b_a22b_eagle3.sh b/progress/SpecForge/examples/run_qwen3_235b_a22b_eagle3.sh new file mode 100644 index 0000000000000000000000000000000000000000..c96b42cb6267bce71fe520669544d9a59eb1cffc --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_235b_a22b_eagle3.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp4/tp8 train eagle3 for Qwen3-30B-A3B +NUM_GPUS=8 +TP_SIZE=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path /workdir/huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct-FP8/\ + --draft-model-config $ROOT_DIR/configs/qwen3-next-80b-a3b-eagle3.json \ + --train-data-path /workdir/data_qwen80b/qwen3_80b_perfectblend_train_regen.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir /workdir/qwen3-80b-regen-blend \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir /workdir/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_qwen3_30b_a3b_eagle3_online.sh b/progress/SpecForge/examples/run_qwen3_30b_a3b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..29b5ac167b044ea321e8f25a5f6f0f5b088dc90c --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_30b_a3b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp4/tp8 train eagle3 for Qwen3-30B-A3B +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-30b-a3b-instruct-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_qwen3_8b_dflash_online.sh b/progress/SpecForge/examples/run_qwen3_8b_dflash_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..556a00764e8f173aaafd3b3c9bfbf292efb0741a --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_8b_dflash_online.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=32 +NUM_GPUS=${1:-8} + +ATTENTION_BACKEND=${2:-flex_attention} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path Qwen/Qwen3-8B \ + --draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfectblend_qwen3-8b_regen.jsonl \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-perfectblend \ + --num-epochs 6 \ + --batch-size 4 \ + --learning-rate 6e-4 \ + --warmup-ratio 0.04 \ + --max-grad-norm 1.0 \ + --max-length 3072 \ + --chat-template qwen \ + --attention-backend $ATTENTION_BACKEND \ + --random-anchor \ + --num-anchors 512 \ + --loss-decay-gamma 7.0 \ + --log-interval 50 \ + --save-interval 1000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-8b-dflash \ + --wandb-name qwen3-8b-dflash-perfectblend diff --git a/progress/SpecForge/examples/run_qwen3_8b_eagle3_online.sh b/progress/SpecForge/examples/run_qwen3_8b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..4aa79654701fb6c05f6430140657ba9e550e4f13 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_8b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp8 train eagle3 for Qwen3-4B/8B/32B up to tp_size = 8 +NUM_GPUS=${1:-1} +TP_SIZE=${2:-1} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-8B \ + --draft-model-config $ROOT_DIR/configs/qwen3-8b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh b/progress/SpecForge/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..b88d5fcdca1cf34ae6f0b050c0a5af390cf50c05 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_coder_30b_a3b_eagle3_online.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# Train EAGLE3 draft model for Qwen3-Coder-30B-A3B-Instruct +# Uses the regenerated OPC dataset and TP=4 on GPUs 4,5,6,7 + +# GPU Configuration - Use the later 4 GPUs (4,5,6,7) +export CUDA_VISIBLE_DEVICES=4,5,6,7 +NUM_GPUS=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-30B-A3B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc_regenerated.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-coder-30b-a3b-instruct-eagle3-opc-regen \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size 4 \ + --dist-timeout 60 \ + --log-interval 50 \ + --save-interval 5000 \ + --eval-interval 5000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-coder \ + --wandb-name qwen3-coder-30b-eagle3-tp4-opc-regen diff --git a/progress/SpecForge/examples/run_qwen3_coder_eagle3_offline.sh b/progress/SpecForge/examples/run_qwen3_coder_eagle3_offline.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7d0f272bfcd23a1073ee6ca012222b7a8a0df82 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_coder_eagle3_offline.sh @@ -0,0 +1,26 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwen3-coder +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-480B-A35B-Instruct \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-480B-A35B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc.jsonl \ + --train-hidden-states-path $ROOT_DIR/cache/hidden_states \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen3-Coder-480B-A35B-Instruct \ + --num-epochs 10 \ + --draft-micro-batch-size 1 \ + --draft-global-batch-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template qwen \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_qwen3_coder_eagle3_online.sh b/progress/SpecForge/examples/run_qwen3_coder_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..77f7803301b9678538dcbcde664ab8c74b5451a4 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_coder_eagle3_online.sh @@ -0,0 +1,33 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwen3-coder +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8 \ + --draft-model-config $ROOT_DIR/configs/qwen3-coder-480B-A35B-instruct-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/opc_regenerated.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/Qwen3-Coder-480B-A35B-Instruct-FP8 \ + --tp-size $TP_SIZE \ + --sglang-ep-size 2 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-5 \ + --ttt-length 13 \ + --sglang-mem-fraction-static 0.6 \ + --max-length 2048 \ + --chat-template qwen \ + --target-model-backend sglang \ + --save-interval 20000 \ + --eval-interval 20000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-480-coder-fp8 \ + --wandb-name qwen3-coder-480b-a35b-eagle3-tp8-ep2-opc-regen diff --git a/progress/SpecForge/examples/run_qwen3_next_80b_eagle3_online.sh b/progress/SpecForge/examples/run_qwen3_next_80b_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..14838913f92ca64ccc7fc49f70389d834bf815d7 --- /dev/null +++ b/progress/SpecForge/examples/run_qwen3_next_80b_eagle3_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +NUM_GPUS=${1:-8} +TP_SIZE=4 +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path $ROOT_DIR//Qwen/Qwen3-Next-80B-A3B-Instruct-FP8/\ + --draft-model-config $ROOT_DIR/configs/qwen3-next-80b-a3b-eagle3.json \ + --train-data-path $ROOT_DIR/data_qwen80b/qwen3_80b_perfectblend_train_regen.jsonl \ + --output-dir $ROOT_DIR/qwen3-80b-regen-blend \ + --num-epochs 2 \ + --batch-size 2 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --sglang-mem-fraction-static 0.5 \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --target-model-backend sglang diff --git a/progress/SpecForge/examples/run_qwq_eagle3_online.sh b/progress/SpecForge/examples/run_qwq_eagle3_online.sh new file mode 100644 index 0000000000000000000000000000000000000000..2b2fae6f19bba8a55a93e134f4cc848e18767d02 --- /dev/null +++ b/progress/SpecForge/examples/run_qwq_eagle3_online.sh @@ -0,0 +1,28 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for qwq-32b +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/QwQ-32B \ + --draft-model-config $ROOT_DIR/configs/qwq-32B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwq-32b-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang diff --git a/progress/SpecForge/scripts/__pycache__/train_dflash_lora.cpython-313.pyc b/progress/SpecForge/scripts/__pycache__/train_dflash_lora.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a29da314ae1aeed9a6a5a14e36f1f096a0cfcf6 Binary files /dev/null and b/progress/SpecForge/scripts/__pycache__/train_dflash_lora.cpython-313.pyc differ diff --git a/progress/SpecForge/scripts/eval_dflash_lora.log b/progress/SpecForge/scripts/eval_dflash_lora.log new file mode 100644 index 0000000000000000000000000000000000000000..0c516994051773830ad78dcddaeeef1c79a66194 --- /dev/null +++ b/progress/SpecForge/scripts/eval_dflash_lora.log @@ -0,0 +1,121 @@ +nohup: ignoring input + +***************************************** +Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +***************************************** +/workspace/hanrui/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/junquan/SpecForge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/hanrui/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/junquan/SpecForge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/hanrui/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/junquan/SpecForge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +/workspace/hanrui/specforge/lib/python3.11/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled. +We recommend installing via `pip install torch-c-dlpack-ext` + warnings.warn( +Set TORCH_CUDA_ARCH_LIST to 9.0 +/workspace/hanrui/junquan/SpecForge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend. + warnings.warn( +`torch_dtype` is deprecated! Use `dtype` instead! + Loading checkpoint shards: 0%| | 0/5 [00:00 Tuple[DFlashLoRADraftModel, OnlineDFlashLoRAModel]: + print_on_rank0(f"Loading base model from {args.model_path}") + + lora_rank = args.lora_rank + lora_alpha = args.lora_alpha + lora_dropout = args.lora_dropout + lora_target_modules = args.lora_target_modules + + if args.lora_config is not None: + with open(args.lora_config) as f: + lora_cfg = json.load(f) + lora_rank = lora_cfg.get("lora_rank", lora_rank) + lora_alpha = lora_cfg.get("lora_alpha", lora_alpha) + lora_dropout = lora_cfg.get("lora_dropout", lora_dropout) + lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules) + print_on_rank0(f"Loaded LoRA config from {args.lora_config}") + + attn_impl = "flex_attention" if args.attention_backend == "flex_attention" else args.attn_implementation + + draft_model = DFlashLoRADraftModel.from_pretrained( + pretrained_model_name_or_path=args.model_path, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + lora_target_modules=lora_target_modules, + block_size=args.block_size, + mask_token_id=args.mask_token_id or 151669, + torch_dtype=torch.bfloat16, + device_map="cuda", + trust_remote_code=args.trust_remote_code, + attn_implementation=attn_impl, + ) + + # Load LoRA weights from checkpoint + print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}") + from peft import PeftModel + draft_model.model = PeftModel.from_pretrained( + draft_model.model.base_model.model, args.ckpt_dir + ) + + online_model = OnlineDFlashLoRAModel( + draft_model=draft_model, + block_size=args.block_size, + mask_token_id=args.mask_token_id or 151669, + loss_decay_gamma=None, + attention_backend=args.attention_backend, + lm_head_chunk_size=args.lm_head_chunk_size, + ) + + return draft_model, online_model + + +def build_dataloader(args, tokenizer): + import hashlib + + cache_params_string = ( + f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + rank = dist.get_rank() + + if os.path.isdir(args.data_path): + dataset = load_dataset(args.data_path, split="train") + else: + dataset = load_dataset("json", data_files=args.data_path)["train"] + + if args.num_samples is not None: + dataset = dataset.select(range(min(args.num_samples, len(dataset)))) + print_on_rank0(f"Using {len(dataset)} samples for eval") + + dataset_kwargs = dict( + dataset=dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + + if rank == 0: + eval_dataset = build_eagle3_dataset(**dataset_kwargs) + dist.barrier() + if rank != 0: + eval_dataset = build_eagle3_dataset(**dataset_kwargs) + + min_loss_tokens = 2 * args.block_size + original_size = len(eval_dataset) + eval_dataset = eval_dataset.filter( + lambda x: x["loss_mask"].sum() >= min_loss_tokens + ) + print_on_rank0(f"Filtered dataset: {original_size} -> {len(eval_dataset)} samples") + + dataloader = prepare_dp_dataloaders( + eval_dataset, + args.batch_size, + num_workers=args.num_workers, + shuffle=False, + process_group=get_dp_group(), + ) + return dataloader + + +def main(): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + warnings.filterwarnings( + "ignore", + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed", + ) + + args = parse_args() + + init_distributed(timeout=args.dist_timeout, tp_size=1) + print_with_rank("Initialized distributed") + + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + if args.mask_token_id is not None: + mask_token_id = args.mask_token_id + elif tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + else: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = tokenizer.mask_token_id + print_on_rank0(f"Using mask_token_id: {mask_token_id}") + args.mask_token_id = mask_token_id + + draft_model, online_model = build_model(args) + draft_model.mask_token_id = mask_token_id + online_model.mask_token_id = mask_token_id + + dataloader = build_dataloader(args, tokenizer) + + draft_model.eval() + online_model.eval() + + total_acc = 0.0 + total_loss = 0.0 + total_steps = 0 + + print_on_rank0(f"Starting eval on {len(dataloader)} batches...") + + with torch.no_grad(): + for step, data in enumerate(dataloader): + input_ids = data["input_ids"].cuda() + attention_mask = data["attention_mask"].cuda() + loss_mask = data["loss_mask"].cuda() + + loss, accuracy = online_model( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + context_len=args.context_len, + ) + + total_acc += accuracy.item() + total_loss += loss.item() + total_steps += 1 + + if (step + 1) % args.log_interval == 0: + avg_acc = total_acc / total_steps + avg_accepted_length = avg_acc * (args.block_size - 1) + print_on_rank0( + f"Step {step + 1}/{len(dataloader)} | " + f"loss: {total_loss / total_steps:.4f} | " + f"acc: {avg_acc:.4f} | " + f"accepted_length: {avg_accepted_length:.4f}" + ) + + # All-reduce across ranks + acc_t = torch.tensor(total_acc / total_steps, device="cuda") + loss_t = torch.tensor(total_loss / total_steps, device="cuda") + dist.all_reduce(acc_t) + dist.all_reduce(loss_t) + world_size = dist.get_world_size() + + final_acc = acc_t.item() / world_size + final_loss = loss_t.item() / world_size + final_accepted_length = final_acc * (args.block_size - 1) + + print_on_rank0( + f"\n=== Eval Results ===\n" + f" Loss: {final_loss:.4f}\n" + f" Accuracy: {final_acc:.4f}\n" + f" Accepted Length: {final_accepted_length:.4f} / {args.block_size - 1}\n" + f" Num batches: {total_steps}\n" + ) + + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/prepare_data.py b/progress/SpecForge/scripts/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4e63658f803506d4d6496b7936537e39a5579f89 --- /dev/null +++ b/progress/SpecForge/scripts/prepare_data.py @@ -0,0 +1,679 @@ +import argparse +import json +import os +import random +import subprocess +from pathlib import Path +from typing import Dict, Tuple + +from tqdm import tqdm + +from datasets import concatenate_datasets, config, load_dataset + +""" +This script will convert the ultrachat/sharegpt dataset to the following schema in jsonl format: +{ + "id": str, + "conversations": [ + { + "role": str, + "content": str + } + ], +} +""" + +ROLE_MAPPING = { + "human": "user", + "gpt": "assistant", + "chatgpt": "assistant", + "bing": "assistant", + "bard": "assistant", +} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=str, + choices=[ + "ultrachat", + "sharegpt", + "eaglechat", + "perfectblend", + "perfectblend-llama3.1-8b-instruct", + "perfectblend-llama3.3-70b-instruct", + "perfectblend-llama4-scout-instruct", + "perfectblend-llama4-maverick-instruct", + "magpie-qwen2.5-pro-1m-v0.1", + "sharegpt4v", + "allava4v", + "opc", + "gsm8k", + "hendrycks_math", + "math_qa", + "codealpaca-20k", + "opencodeinstruct", + "magicoder-evol-instruct", + "sciq", + "camel", + ], + help="The demo dataset to quickly run the training for speculative decoding", + ) + parser.add_argument( + "--output-path", + type=str, + default=None, + help="The path to save the processed dataset, if not specified, the dataset will be saved in the cache/dataset/dataset_name directory of the root path", + ) + parser.add_argument( + "--data-path", + type=str, + default=None, + help="The path to the custom dataset, if not specified, the default dataset will be loaded", + ) + parser.add_argument( + "--sample-size", + type=int, + default=None, + help="The number of samples to process from the dataset, if not specified, all samples will be processed", + ) + parser.add_argument( + "--split-eval", + action="store_true", + help="Whether to split the dataset into train and eval sets, default is False", + ) + parser.add_argument( + "--opc-subset", + type=str, + default="largescale_diverse_instruct", + choices=[ + "largescale_diverse_instruct", + "filtered_infinity_instruct", + "realuser_instruct", + "all", + ], + help="The subset of OpenCoder opc-sft-stage1 dataset to use, or 'all' to use all subsets (default: largescale_diverse_instruct)", + ) + return parser.parse_args() + + +def get_cache_dir(dataset_name): + cache_dir = None + if dataset_name == "sharegpt4v": + raise ValueError("Downloading 'sharegpt4v' is not supported.") + elif dataset_name == "allava4v": + cache_dir = os.path.join( + config.HF_DATASETS_CACHE, "FreedomIntelligence", "ALLaVA" + ) + else: + raise ValueError( + f"Dataset '{dataset_name}' is not a supported VLM dataset for download." + ) + return cache_dir + + +def download_vlm_dataset(dataset_name: str) -> None: + """Download VLM's dataset such as sharegpt4v and allava4v""" + if dataset_name == "sharegpt4v": + raise Exception("Don't Support Download sharegpt4v.") + elif dataset_name == "allava4v": + cache_dir = get_cache_dir(dataset_name) + os.makedirs(cache_dir, exist_ok=True) + script_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "datasets", + "download_laion.sh", + ) + os.chmod(script_path, 0o755) + if not os.path.exists( + os.path.join(cache_dir, "allava_laion", "image_chunks", "images_0.zip") + ): + result = subprocess.run( + ["bash", script_path], + cwd=cache_dir, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"Download image dataset failed: {result.stderr}") + print("##### allava4v dataset Download Complete #####") + else: + print("##### allava4v dataset has existed.") + else: + raise Exception(f"Don't support {dataset_name}") + + +def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the ultrachat dataset. + + The function expects a row with the following schema: + "messages": [ + { + "role": "user" | "assistant", + "content": str + } + ] + """ + conversations = row["messages"] + formatted_conversations = [] + for message in conversations: + role = message["role"] + content = message["content"] + assert role in ["user", "assistant"] + formatted_conversations.append({"role": role, "content": content}) + row = {"id": row["prompt_id"], "conversations": formatted_conversations} + return row, 0 + + +def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """ + sharegpt dataset schema: + { + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + """ + conversations = row["conversations"] + formatted_conversations = [] + skipped_count = 0 + for message in conversations: + if message["from"] not in ROLE_MAPPING: + skipped_count += 1 + continue + new_role = ROLE_MAPPING[message["from"]] + content = message["value"] + formatted_conversations.append({"role": new_role, "content": content}) + + row = {"id": row["id"], "conversations": formatted_conversations} + return row, skipped_count + + +def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict: + """ + sharegpt4v dataset schema: + { + "id": str, + "image": str, # path to the image + "conversations": [ + { + "from": , + "value": , + }, + ... + ] + } + """ + cache_dir = get_cache_dir(dataset_name) + conversations = row["conversations"] + image = os.path.join(cache_dir, row["image"]) + if not os.path.exists(image): + print(f"Image path {image} does not exist, skipping this sample.") + return None, None + formatted_conversations = [] + skipped_count = 0 + for message in conversations: + if message["from"] not in ROLE_MAPPING: + skipped_count += 1 + continue + new_role = ROLE_MAPPING[message["from"]] + if new_role == "user": + text_content = message["value"].replace("\n", "") + content = text_content + else: + content = message["value"] + formatted_conversations.append({"role": new_role, "content": content}) + + row = {"id": row["id"], "image": image, "conversations": formatted_conversations} + return row, skipped_count + + +def load_dataset_from_path(data_path: Path): + suffix = data_path.suffix.split(".")[1] + ds = load_dataset(suffix, data_files=str(data_path), split="train") + return ds + + +def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name): + train_output_jsonl_path = output_path.joinpath(f"{dataset_name}_train.jsonl") + if train_output_jsonl_path.exists(): + print( + f"The dataset {dataset_name} has already been processed and saved in {train_output_jsonl_path}, skipping..." + ) + return + + total_skipped_count = 0 + with open(train_output_jsonl_path, "w") as f: + for item in tqdm(train_ds, desc=f"Processing {dataset_name} dataset"): + if proc_fn is not None: + row, skipped_count = proc_fn(item, dataset_name) + if row is None: + continue + total_skipped_count += skipped_count + else: + row = item + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + if test_ds is not None: + test_output_jsonl_path = output_path.joinpath(f"{dataset_name}_test.jsonl") + with open(test_output_jsonl_path, "w") as f: + for item in tqdm(test_ds, desc=f"Processing {dataset_name} test dataset"): + if proc_fn is not None: + row, skipped_count = proc_fn(item, dataset_name) + if row is None: + continue + total_skipped_count += skipped_count + else: + row = item + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + if total_skipped_count > 0: + total_messages = len(train_ds) + (len(test_ds) if test_ds is not None else 0) + print( + f"Skipped {total_skipped_count}/{total_messages} messages for {dataset_name}" + ) + + +import hashlib + + +def process_opc_sft_stage1(row: Dict) -> Tuple[Dict, int]: + row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["instruction"]}, + {"role": "assistant", "content": row["output"]}, + ], + } + return processed_row, 0 + + +def process_codealpaca_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the CodeAlpaca-20k dataset. + + The function expects a row with the following schema: + { + "instruction": str, + "input": str, + "output": str + } + """ + row_id = hashlib.md5((row["instruction"] + row["output"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["instruction"]}, + {"role": "assistant", "content": row["output"]}, + ], + } + return processed_row, 0 + + +def process_opencodeinstruct_row( + row: Dict, dataset_name: str = None +) -> Tuple[Dict, int]: + """Process a row from the nvidia/OpenCodeInstruct dataset. + + The function expects a row with the following schema: + { + "id": str, + "input": str, + "output": str, + "domain": str, + "generation_algorithm": str, + "llm_judgement": str, + "unit_tests": str, + "tests_execution_status": str, + "average_test_score": float + } + """ + # Use the existing id if available, otherwise generate one + row_id = row.get("id") + if row_id is None: + row_id = hashlib.md5((row["input"] + row["output"]).encode()).hexdigest() + + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["input"]}, + {"role": "assistant", "content": row["output"]}, + ], + } + return processed_row, 0 + + +def process_magicoder_evol_instruct_row( + row: Dict, dataset_name: str = None +) -> Tuple[Dict, int]: + """Process a row from the ise-uiuc/Magicoder-Evol-Instruct-110K dataset. + + The function expects a row with the following schema: + { + "instruction": str, + "response": str + } + """ + row_id = hashlib.md5((row["instruction"] + row["response"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["instruction"]}, + {"role": "assistant", "content": row["response"]}, + ], + } + return processed_row, 0 + + +def process_gsm8k_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the gsm8k dataset. + + The function expects a row with the following schema: + { + "question": str, + "answer": str + } + """ + row_id = hashlib.md5((row["question"] + row["answer"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["question"]}, + {"role": "assistant", "content": row["answer"]}, + ], + } + return processed_row, 0 + + +def process_hendrycks_math_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the hendrycks_math dataset. + + The function expects a row with the following schema: + { + "problem": str, + "solution": str, + "level": str, + "type": str + } + """ + row_id = hashlib.md5((row["problem"] + row["solution"]).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": row["problem"]}, + {"role": "assistant", "content": row["solution"]}, + ], + } + return processed_row, 0 + + +def process_math_qa_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the allenai/math_qa dataset. + + The function expects a row with the following schema: + { + "Problem": str, + "Rationale": str, + "options": str, # format: "a) option1 b) option2 c) option3 d) option4" + "correct": str, + "annotated_formula": str, + "linear_formula": str, + "category": str + } + """ + # Combine Problem and options as user input + problem = row["Problem"] + options = row["options"] + user_content = f"{problem}\n{options}" + + # Use Rationale as assistant response + rationale = row["Rationale"] + + row_id = hashlib.md5((user_content + rationale).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": rationale}, + ], + } + return processed_row, 0 + + +def process_sciq_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the allenai/sciq dataset. + + The function expects a row with the following schema: + { + "question": str, + "distractor3": str, + "distractor1": str, + "distractor2": str, + "correct_answer": str, + "support": str + } + """ + question = row["question"] + correct_answer = row["correct_answer"] + distractor1 = row["distractor1"] + distractor2 = row["distractor2"] + distractor3 = row["distractor3"] + support = row["support"] + + # Create a list of all answers and randomly shuffle them + answers_list = [distractor3, distractor1, distractor2, correct_answer] + random.shuffle(answers_list) + + # Assign shuffled answers to labels a, b, c, d + labels = ["a", "b", "c", "d"] + options_list = [(labels[i], answers_list[i]) for i in range(4)] + + # Find the correct answer label after shuffling + correct_label = None + for label, answer in options_list: + if answer == correct_answer: + correct_label = label + break + + # Format options as a string + options_text = "\n".join([f"{label}) {answer}" for label, answer in options_list]) + user_content = f"{question}\n{options_text}" + + # Combine support with answer + assistant_content = f"{support}\nanswer: {correct_label}) {correct_answer}" + + row_id = hashlib.md5((user_content + assistant_content).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": assistant_content}, + ], + } + return processed_row, 0 + + +def process_camel_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the camel-ai dataset. + + The function expects a row with the following schema: + { + "message_1": str, # user message + "message_2": str, # assistant message + } + """ + message_1 = row["message_1"] + message_2 = row["message_2"] + + row_id = hashlib.md5((message_1 + message_2).encode()).hexdigest() + processed_row = { + "id": row_id, + "conversations": [ + {"role": "user", "content": message_1}, + {"role": "assistant", "content": message_2}, + ], + } + return processed_row, 0 + + +def add_index(row, idx) -> Dict: + row["id"] = idx + return row + + +def main(): + args = parse_args() + # load dataset + if args.dataset == "ultrachat": + ds = load_dataset("HuggingFaceH4/ultrachat_200k")["train_sft"] + proc_fn = process_ultrachat_row + elif args.dataset == "sharegpt": + if args.data_path is None: + ds = load_dataset("Aeala/ShareGPT_Vicuna_unfiltered")["train"] + else: + print("Loading dataset from custom data path: ", args.data_path) + ds = load_dataset_from_path(Path(args.data_path)) + proc_fn = process_sharegpt_row + elif args.dataset == "eaglechat": + ds = load_dataset("zhaode/EagleChat")["train"] + proc_fn = lambda row, name: (row, 0) + elif args.dataset == "perfectblend": + ds = load_dataset("mlabonne/open-perfectblend")["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = process_sharegpt_row + elif args.dataset == "perfectblend-llama3.1-8b-instruct": + ds = load_dataset("frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct")[ + "train" + ] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama3.3-70b-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama4-scout-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "perfectblend-llama4-maverick-instruct": + ds = load_dataset( + "frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct" + )["train"] + ds = ds.map(add_index, with_indices=True) + proc_fn = None + elif args.dataset == "magpie-qwen2.5-pro-1m-v0.1": + ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"] + ds = ds.rename_column("uuid", "id") + proc_fn = process_sharegpt_row + elif args.dataset == "sharegpt4v": + ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"] + raise Exception("Not supported sharegpt4v now") + download_vlm_dataset(args.dataset) + proc_fn = process_sharegpt4v_row + elif args.dataset == "allava4v": + ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[ + "instruct" + ] + download_vlm_dataset(args.dataset) + proc_fn = process_sharegpt4v_row + elif args.dataset == "opc": + if args.opc_subset == "all": + # Load all subsets and concatenate them + subsets = [ + "largescale_diverse_instruct", + "filtered_infinity_instruct", + "realuser_instruct", + ] + datasets_list = [ + load_dataset("OpenCoder-LLM/opc-sft-stage1", subset)["train"] + for subset in subsets + ] + ds = concatenate_datasets(datasets_list) + else: + ds = load_dataset("OpenCoder-LLM/opc-sft-stage1", args.opc_subset)["train"] + proc_fn = process_opc_sft_stage1 + elif args.dataset == "gsm8k": + ds = load_dataset("openai/gsm8k", "main")["train"] + proc_fn = process_gsm8k_row + elif args.dataset == "hendrycks_math": + # Load all subjects and concatenate them + subjects = [ + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", + ] + datasets_list = [ + load_dataset("EleutherAI/hendrycks_math", subject)["train"] + for subject in subjects + ] + ds = concatenate_datasets(datasets_list) + proc_fn = process_hendrycks_math_row + elif args.dataset == "math_qa": + ds = load_dataset("allenai/math_qa", trust_remote_code=True)["train"] + proc_fn = process_math_qa_row + elif args.dataset == "codealpaca-20k": + ds = load_dataset("sahil2801/CodeAlpaca-20k", trust_remote_code=True)["train"] + proc_fn = process_codealpaca_row + elif args.dataset == "opencodeinstruct": + ds = load_dataset("nvidia/OpenCodeInstruct", trust_remote_code=True)["train"] + proc_fn = process_opencodeinstruct_row + elif args.dataset == "magicoder-evol-instruct": + ds = load_dataset( + "ise-uiuc/Magicoder-Evol-Instruct-110K", trust_remote_code=True + )["train"] + proc_fn = process_magicoder_evol_instruct_row + elif args.dataset == "sciq": + ds = load_dataset("allenai/sciq", trust_remote_code=True)["train"] + proc_fn = process_sciq_row + elif args.dataset == "camel": + # Load all three camel-ai datasets and concatenate them + camel_datasets = [ + load_dataset("camel-ai/biology", split="train"), + load_dataset("camel-ai/chemistry", split="train"), + load_dataset("camel-ai/physics", split="train"), + ] + ds = concatenate_datasets(camel_datasets) + proc_fn = process_camel_row + else: + raise ValueError( + f"This script only supports ultrachat, sharegpt, sharegpt4v, allava4v, opc, gsm8k, hendrycks_math, math_qa, codealpaca-20k, opencodeinstruct, magicoder-evol-instruct, sciq, camel, and perfect-blend-gptoss-20B datasets for demo purpose, if you wish to use other datasets, please modify this script." + ) + # filter and split dataset + if args.sample_size is not None and args.sample_size < len(ds): + ds = ds.select(range(args.sample_size)) + print(f"Processing {args.sample_size} samples from the dataset {args.dataset}") + if args.split_eval: + ds = ds.train_test_split(test_size=0.05) + train_ds = ds["train"] + test_ds = ds["test"] + else: + train_ds = ds + test_ds = None + + if args.output_path is None: + root_path = Path(__file__).parent.parent + output_path = root_path.joinpath("cache", "dataset") + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path(args.output_path) + output_path.mkdir(parents=True, exist_ok=True) + + process_and_save_ds(train_ds, test_ds, output_path, proc_fn, args.dataset) + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/prepare_hidden_states.py b/progress/SpecForge/scripts/prepare_hidden_states.py new file mode 100644 index 0000000000000000000000000000000000000000..304677d265bd7ab6cde6c3c252d5bdebbebb4f56 --- /dev/null +++ b/progress/SpecForge/scripts/prepare_hidden_states.py @@ -0,0 +1,716 @@ +""" +This script will generate the hidden states for the dataset use transformer as the target model backend. +By generating hidden states in advance, we can avoid: +- the memory overhead of loading target model +- the latency overhead of generating hidden states for each request. + +Optimized for lower memory usage and higher efficiency. + +Usage: +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/sharegpt_train.jsonl \ + --output-path ./cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --chat-template llama3 \ + --max-length 2048 \ + --tp-size 1 \ + --batch-size 32 \ + --num-samples 1000 \ + --output-path ./cache/hidden_states + +For pre-formatted data (with chat template already applied), add --is-preformatted: +torchrun --nproc_per_node=8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/preformatted_data.jsonl \ + --output-path ./cache/hidden_states \ + --chat-template llama3 \ + --is-preformatted \ + --max-length 2048 +""" + +import argparse +import gc +import gzip +import hashlib +import os +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from tqdm import tqdm +from transformers import AutoConfig, AutoProcessor, AutoTokenizer + +from datasets import Dataset +from specforge.args import SGLangBackendArgs +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import ( + destroy_distributed, + get_dp_group, + get_tp_group, + init_distributed, + is_tp_rank_0, +) +from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model +from specforge.utils import ( + print_args_with_dots, + print_with_rank, + rank_0_priority, + safe_conversations_generator, +) + + +@dataclass +class DataPoint: + input_ids: torch.Tensor + loss_mask: torch.Tensor + hidden_state: torch.Tensor + aux_hidden_state: Optional[torch.Tensor] = None + + +def parse_args(): + parser = argparse.ArgumentParser() + + # model-related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code when loading models", + ) + model_group.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) + model_group.add_argument("--enable-aux-hidden-states", action="store_true") + model_group.add_argument("--aux-hidden-states-layers", type=str, default=None) + + data_group = parser.add_argument_group("data") + data_group.add_argument("--data-path", type=str, required=True) + data_group.add_argument("--max-length", type=int, default=2048) + data_group.add_argument("--chat-template", type=str, default="llama3") + data_group.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) + data_group.add_argument("--num-samples", type=int, default=None) + data_group.add_argument("--build-dataset-num-proc", type=int, default=8) + + inference_group = parser.add_argument_group("inference") + inference_group.add_argument("--tp-size", type=int, default=1) + inference_group.add_argument("--batch-size", type=int, default=32) + + others_group = parser.add_argument_group("others") + others_group.add_argument("--cache-dir", type=str, default="./cache") + others_group.add_argument("--output-path", type=str, default=None) + others_group.add_argument( + "--model-download-dir", + type=str, + default=None, + help="The directory to download the target model to", + ) + others_group.add_argument( + "--dist-timeout", + type=int, + default=2000, + help="Timeout for collective communication in minutes, default to 2000 so that it does not go timeout", + ) + others_group.add_argument( + "--num-io-threads", + type=int, + default=None, + help="Number of threads for async I/O operations (default: all of CPU cores).", + ) + others_group.add_argument( + "--num-workers", type=int, default=4, help="Number of workers for DataLoader" + ) + others_group.add_argument( + "--io-queue-size", + type=int, + default=50, + help="Max number of pending I/O futures.", + ) + others_group.add_argument( + "--file-group-size", + type=int, + default=2000, + help="Number of files per subdirectory.", + ) + others_group.add_argument( + "--compress", + action="store_true", + help="Compress hidden state files on disk (gzip).", + ) + others_group.add_argument( + "--compression-level", + type=int, + default=6, + help="Gzip compression level (1-9).", + ) + + sglang_group = parser.add_argument_group("sglang") + SGLangBackendArgs.add_args(sglang_group) + return parser.parse_args() + + +def build_target_model( + args: argparse.Namespace, model_config: AutoConfig +) -> Tuple[Eagle3TargetModel, Optional[AutoProcessor]]: + """ + Build the target model according to the arguments. + + For VLM models (Qwen2.5-VL) without TP, load directly from transformers. + Otherwise, use the Eagle3 target model wrapper. + """ + if args.is_vlm and model_config.model_type == "qwen2_5_vl" and args.tp_size == 1: + # TODO: replace with sglang + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=( + model_config.dtype + if hasattr(model_config, "dtype") + else model_config.torch_dtype + ), + ) + .eval() + .cuda() + ) + else: + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + target_model = get_eagle3_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend="sglang", # we set this as the default backend to minimize precision mismatch in training and serving + torch_dtype=( + model_config.dtype + if hasattr(model_config, "dtype") + else model_config.torch_dtype + ), + device="cuda", + cache_dir=args.model_download_dir, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) + # Set auxiliary hidden states layers if specified + target_model.set_aux_hidden_states_layers(args.aux_hidden_states_layers) + + if args.is_vlm: + processor = AutoProcessor.from_pretrained(args.target_model_path) + else: + processor = None + + return target_model, processor + + +class HiddenStatesGenerator: + """ + This is a generator for creating and saving the hidden states based on the target model. + It includes the following features: + 1. Fixes a potential deadlock in TP > 1 scenarios when a batch is skipped. + 2. Implements a context manager (`with` statement) for robust resource handling. + 3. Makes internal settings (like queue sizes, group sizes) configurable. + 4. Centralizes resource cleanup logic. + """ + + def __init__( + self, + target_model, + enable_aux_hidden_states: bool = True, + num_io_threads: int = 4, + io_queue_size: int = 50, + file_group_size: int = 2000, + compress: bool = False, + compression_level: int = 6, + ): + """ + Args: + target_model: The model for inference. + enable_aux_hidden_states: Whether to save auxiliary hidden states. + num_io_threads: Number of threads for async I/O. + io_queue_size: Max number of pending I/O futures before cleanup. + file_group_size: Number of files per subdirectory. + """ + self.model = target_model + self.enable_aux_hidden_states = enable_aux_hidden_states + + # --- Configurable parameters --- + self.num_io_threads = num_io_threads + self.io_queue_size = io_queue_size + self.file_group_size = file_group_size + self.compress = compress + self.compression_level = compression_level + self.file_extension = ".ckpt.gz" if self.compress else ".ckpt" + + # progress bar should only shown on TP rank = 0 + self.show_progress = dist.get_rank(get_tp_group()) == 0 + + # --- REFACTOR: Thread pool is now managed by __enter__ and __exit__ --- + self.io_executor = None + self.pending_futures = [] + + def __enter__(self): + """Initializes resources when entering a 'with' block.""" + if is_tp_rank_0(): + self.io_executor = ThreadPoolExecutor(max_workers=self.num_io_threads) + self.pending_futures = [] + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleans up resources when exiting a 'with' block.""" + if is_tp_rank_0() and self.io_executor is not None: + if self.show_progress: + print("\nWaiting for all async I/O operations to complete...") + self._wait_all_saves() + self.io_executor.shutdown(wait=True) + self.io_executor = None # Reset for safety + + # Final barrier to ensure all processes exit generate() cleanly + dist.barrier() + + def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None: + """ + Save a data point to a file synchronously. If there is any NaN value in the data, this datapoint will be skipped. + + Args: + data_point (DataPoint): The data point to save. + output_file (str): The path to the output file. + """ + if data_point.hidden_state is not None and torch.any( + torch.isnan(data_point.hidden_state) + ): + print( + f"Warning: NaN found in hidden_state for {output_file}. Skipping save." + ) + return + + if data_point.aux_hidden_state is not None and torch.any( + torch.isnan(data_point.aux_hidden_state) + ): + print( + f"Warning: NaN found in aux_hidden_state for {output_file}. Skipping save." + ) + return + + if self.compress: + with gzip.open( + output_file, "wb", compresslevel=self.compression_level + ) as f: + torch.save(asdict(data_point), f) + else: + torch.save(asdict(data_point), output_file) + + def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None: + """ + Submit a job to the io_executor to save the data point asynchronously. + + Args: + data_point (DataPoint): The data point to save. + output_file (str): The path to the output file. + """ + assert is_tp_rank_0(), "Only tp_rank=0 should call _save_tensor_async" + # If the queue of pending save operations is full, we must wait. + if len(self.pending_futures) >= self.io_queue_size: + # First, try to clear any futures that have already finished without waiting. + self.pending_futures = [f for f in self.pending_futures if not f.done()] + # If the queue is *still* full, it means all I/O threads are busy and we have + # a backlog. We must now block the main generation loop and wait for the + # oldest I/O operation to complete before proceeding. + if len(self.pending_futures) >= self.io_queue_size: + self.pending_futures.pop(0).result() + + future = self.io_executor.submit( + self._save_tensor_sync, data_point, output_file + ) + self.pending_futures.append(future) + + def _wait_all_saves(self): + """ + This method is to ensure that all submitted jobs are completed. + """ + if is_tp_rank_0() and self.pending_futures: + for future in tqdm( + self.pending_futures, + desc="Finalizing Writes", + disable=not self.show_progress, + ): + future.result() # Wait and raise exception if any + self.pending_futures.clear() + + def _prepare_output_dirs( + self, output_path: str, start_idx: int, total_samples: int + ) -> None: + """ + The dataset is organized into groups of files, each group has a folder which contains the files for this group. For example, if the + file_group_size is 2000, the 0-1999 samples will be saved in the folder "rows_0-2000", the 2000-3999 samples will be saved in the folder "rows_2000-4000", etc. + + Args: + output_path (str): The path to the output directory. + start_idx (int): The starting index of the samples to save. + total_samples (int): The total number of samples to save. + + Returns: + None + """ + if not is_tp_rank_0() or total_samples == 0: + return + start_group = (start_idx // self.file_group_size) * self.file_group_size + end_sample_idx = start_idx + total_samples - 1 + end_group = (end_sample_idx // self.file_group_size) * self.file_group_size + for group_start_idx in range(start_group, end_group + 1, self.file_group_size): + grouped_subdir = ( + f"rows_{group_start_idx}-{group_start_idx + self.file_group_size}" + ) + output_dir = os.path.join(output_path, grouped_subdir) + os.makedirs(output_dir, exist_ok=True) + + def _check_existing_files_batch( + self, output_path: str, global_indices: List[int] + ) -> List[bool]: + """ + A helper function to check if the files for the given global indices exist. + + Args: + output_path (str): The path to the output directory. + global_indices (List[int]): The global indices of the samples to check. + + Returns: + List[bool]: A list of booleans indicating if the files for the given global indices exist. + """ + if not is_tp_rank_0(): + return [False] * len(global_indices) + + def check_single_file(idx): + if os.path.exists(self._get_file_path(output_path, idx)): + return True + legacy_ckpt = self._get_file_path(output_path, idx, extension=".ckpt") + compressed_ckpt = self._get_file_path( + output_path, idx, extension=".ckpt.gz" + ) + return os.path.exists(legacy_ckpt) or os.path.exists(compressed_ckpt) + + # Parallel file existence check + with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor: + exists = list(executor.map(check_single_file, global_indices)) + return exists + + def _get_file_path( + self, output_path: str, idx: int, extension: Optional[str] = None + ) -> str: + """ + A helper function to get the standard file path for the data point with the given index. + + Args: + output_path (str): The path to the output directory. + idx (int): The global index of the data point. + + Returns: + str: The file path for the data point. + """ + ext = self.file_extension if extension is None else extension + group_idx = (idx // self.file_group_size) * self.file_group_size + grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}" + return os.path.join(output_path, grouped_subdir, f"data_{idx}{ext}") + + @torch.no_grad() + def generate( + self, + data_loader: torch.utils.data.DataLoader, + output_path: str, + start_idx: int = 0, + samples_per_dp: int = 0, + ): + """ + This version prioritizes minimal CPU RAM usage above all else, even at the cost of performance. + - It processes samples one-by-one within the tp_rank_0 process. + - It avoids batching GPU-to-CPU transfers. + - It ensures only one sample's data is in RAM for I/O at any given time. + """ + self._prepare_output_dirs(output_path, start_idx, samples_per_dp) + + tp_group = get_tp_group() + tp_group_ranks = dist.get_process_group_ranks(tp_group) + tp_rank_0_global = tp_group_ranks[0] + global_idx = start_idx + + progress_bar = tqdm( + data_loader, + disable=(not self.show_progress), + desc="Generating Hidden States", + position=dist.get_rank(get_dp_group()), + leave=True, + ) + + total_skipped, total_processed = 0, 0 + + for batch_idx, batch in enumerate(progress_bar): + batch_size = batch["input_ids"].size(0) + current_batch_indices = list(range(global_idx, global_idx + batch_size)) + + # # Step 1: Synchronize valid indices across TP group + # we check which files already exist and sync this info across TP ranks + # if exists, we will skip these samples + if is_tp_rank_0(): + exists_list = self._check_existing_files_batch( + output_path, current_batch_indices + ) + exists_tensor = torch.tensor( + exists_list, dtype=torch.bool, device="cuda" + ) + else: + exists_tensor = torch.tensor( + [False] * batch_size, dtype=torch.bool, device="cuda" + ) + dist.broadcast(exists_tensor, src=tp_rank_0_global, group=tp_group) + + # Step 1: TP rank 0 checks which samples need processing + valid_indices_in_batch = [ + i for i, exists in enumerate(exists_tensor) if not exists + ] + sample_global_indices = [ + current_batch_indices[i] for i in valid_indices_in_batch + ] + num_valid = len(valid_indices_in_batch) + total_skipped += batch_size - num_valid + + # Step 2: Filter batch before moving to GPU to save memory + global_idx += batch_size + filtered_batch = { + "input_ids": batch["input_ids"][valid_indices_in_batch], + "attention_mask": batch["attention_mask"][valid_indices_in_batch], + "loss_mask": batch["loss_mask"][valid_indices_in_batch], + } + del batch + if num_valid == 0: + # Data has already been generated, no sample processing, update progress bar. + if self.show_progress: + progress_bar.set_postfix( + { + "processed": total_processed, + "skipped": total_skipped, + "pending_io": ( + len(self.pending_futures) if is_tp_rank_0() else 0 + ), + } + ) + continue + + filtered_batch_gpu = { + k: v.cuda(non_blocking=True) for k, v in filtered_batch.items() + } + _, _, aux_hidden_states_list, last_hidden_states_list = self.model.extend( + **filtered_batch_gpu, + return_last_hidden_states=True, + return_logits=False, + ) + + del filtered_batch_gpu + + if is_tp_rank_0(): + for i, ( + current_global_idx, + aux_hidden_states, + last_hidden_states, + ) in enumerate( + zip( + sample_global_indices, + aux_hidden_states_list, + last_hidden_states_list, + ) + ): + + # Process ONE sample at a time to minimize CPU RAM footprint + # 1. Transfer only the required slice for one sample to CPU + aux_hidden_states = ( + aux_hidden_states.cpu().clone().unsqueeze(0) + if aux_hidden_states is not None + else None + ) + last_hidden_states = ( + last_hidden_states.cpu().clone().unsqueeze(0) + if last_hidden_states is not None + else None + ) + data_point = DataPoint( + input_ids=filtered_batch["input_ids"][i].clone(), + loss_mask=filtered_batch["loss_mask"][i].clone(), + hidden_state=last_hidden_states, + aux_hidden_state=aux_hidden_states, + ) + + # 3. Save asynchronously (the backpressure logic is still crucial) + output_file = self._get_file_path(output_path, current_global_idx) + self._save_tensor_async(data_point, output_file) + + # 4. Immediately clean up the single-sample CPU tensors + del last_hidden_states, aux_hidden_states + + total_processed += len(sample_global_indices) + + # Clean up the large GPU and CPU batch data + del aux_hidden_states_list, last_hidden_states_list, filtered_batch + + if batch_idx % 5 == 0: # Make GC and cache clearing more frequent + torch.cuda.empty_cache() + gc.collect() + + if self.show_progress: + progress_bar.set_postfix( + { + "processed": total_processed, + "skipped": total_skipped, + "pending_io": ( + len(self.pending_futures) if is_tp_rank_0() else 0 + ), + } + ) + + if self.show_progress: + print( + f"\nGeneration loop finished. Processed: {total_processed}, Skipped: {total_skipped}" + ) + dist.barrier() + + +def main(): + args = parse_args() + if args.aux_hidden_states_layers is not None: + args.aux_hidden_states_layers = [ + int(x) for x in args.aux_hidden_states_layers.split(",") + ] + if args.num_io_threads is None: + cpu_cores = os.cpu_count() or 1 + args.num_io_threads = max(1, cpu_cores) + # Initialize distributed environment (TP + DP) + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_args_with_dots(args) + + # Build target model (with TP) + target_model_config = AutoConfig.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + target_model, processor = build_target_model(args, target_model_config) + + print_with_rank( + f"DP Rank {dist.get_rank(get_dp_group())}, TP Rank {dist.get_rank(get_tp_group())}, " + f"DP Size {dist.get_world_size(get_dp_group())}, TP Size {dist.get_world_size(get_tp_group())}" + ) + + if args.output_path is None: + args.output_path = os.path.join( + Path(__file__).parent.parent, "cache", "hidden_states" + ) + + # Load complete dataset + assert os.path.exists( + args.data_path + ), f"Dataset path {args.data_path} does not exist" + dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.data_path}, + cache_dir=os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "cache", + "hf_dataset", + ), + ) + if args.num_samples is not None: + dataset = dataset.select(range(args.num_samples)) + # Tokenizer and cache key + tokenizer = AutoTokenizer.from_pretrained( + args.target_model_path, trust_remote_code=True + ) + cache_params_string = f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.target_model_path}-{args.num_samples}-{args.is_preformatted}" + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + # Preprocess on complete, un-sharded dataset + with rank_0_priority(): + print_with_rank("Main process is building the dataset cache...") + eagle3_dataset = build_eagle3_dataset( + dataset=dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, + processor=processor, + num_proc=args.build_dataset_num_proc, + ) + print_with_rank(f"Dataset prepared with {len(eagle3_dataset)} samples.") + + # Create DP-sharded dataloader + data_loader = prepare_dp_dataloaders( + dataset=eagle3_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=False, + process_group=get_dp_group(), + is_vlm=args.is_vlm, + ) + + print_with_rank( + f"DataLoader created for DP Rank {dist.get_rank(get_dp_group())}. " + f"Number of batches: {len(data_loader)}" + ) + + # Calculate starting index and sample count for current DP rank + total = len(eagle3_dataset) + dp_rank = dist.get_rank(get_dp_group()) + dp_size = dist.get_world_size(get_dp_group()) + + # Calculate samples per DP rank (handle non-divisible case) + samples_per_dp = total // dp_size + remainder = total % dp_size + + # Earlier ranks handle one extra sample if there's a remainder + if dp_rank < remainder: + samples_per_dp += 1 + start_idx = dp_rank * samples_per_dp + else: + start_idx = dp_rank * samples_per_dp + remainder + + print_with_rank( + f"DP Rank {dp_rank} will process {samples_per_dp} samples, " + f"starting from index {start_idx}" + ) + + # Generate hidden states + try: + # Pass configurable arguments from args if needed + with HiddenStatesGenerator( + target_model, + enable_aux_hidden_states=args.enable_aux_hidden_states, + num_io_threads=args.num_io_threads, + io_queue_size=args.io_queue_size, + file_group_size=args.file_group_size, + compress=args.compress, + compression_level=args.compression_level, + # Other params like io_queue_size can also be added to argparse + ) as hidden_states_generator: + + # Generate hidden states + hidden_states_generator.generate( + data_loader, + output_path=args.output_path, + start_idx=start_idx, + samples_per_dp=samples_per_dp, + ) + + finally: + # The finally block ensures destroy_distributed is always called + print_with_rank("All hidden states generated or job finished.") + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/regenerate_train_data.py b/progress/SpecForge/scripts/regenerate_train_data.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0470c68cd6eeb6d25454c6b6064a719c439645 --- /dev/null +++ b/progress/SpecForge/scripts/regenerate_train_data.py @@ -0,0 +1,453 @@ +""" +This script will re-generate the dataset from target model, +which better aligns the draft model with the target model’s output distribution. + +Usage: +1. Set up one or more SGLang servers for the target model. + +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 128 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 + + +2. Regenerate the dataset using the `regenerate_train_data.py` script. +python scripts/regenerate_train_data.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --concurrency 128 \ + --max-tokens 4096 \ + --server-address localhost:30000 \ + --temperature 0.8 \ + --input-file-path ./cache/dataset/sharegpt_train.jsonl \ + --output-file-path ./cache/dataset/sharegpt_train_regen.jsonl +""" + +import argparse +import json +import os +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List + +from openai import OpenAI +from tqdm import tqdm + + +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Re-generate training data using sglang model server" + ) + + # model related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--model", type=str, required=True) + model_group.add_argument( + "--is-reasoning-model", + action="store_true", + help="Whether the model is a reasoning model", + ) + model_group.add_argument( + "--is-gpt-oss", + action="store_true", + help="Whether the model is a GPT-OSS model", + ) + + # sampling params + sampling_params_group = parser.add_argument_group("sampling parameters") + sampling_params_group.add_argument( + "--temperature", + type=float, + default=0.7, + help="Temperature for sglang model server", + ) + sampling_params_group.add_argument( + "--top-p", + type=float, + default=None, + help="Nucleus sampling top_p", + ) + sampling_params_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling value sent via extra_body", + ) + sampling_params_group.add_argument( + "--repetition-penalty", + type=float, + default=None, + help="Mapped to presence_penalty in the OpenAI API", + ) + sampling_params_group.add_argument( + "--max-tokens", + type=int, + default=4096, + help="Maximum number of tokens (default: 4096)", + ) + + # optimization + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--concurrency", + type=int, + default=64, + help="The number of requests to send to a single server concurrently, the total number of concurrent requests is concurrency * number of server addresses", + ) + + # data related arguments + data_group = parser.add_argument_group("data") + data_group.add_argument( + "--input-file-path", type=str, required=True, help="Path to the input file" + ) + data_group.add_argument( + "--output-file-path", type=str, required=True, help="Path to the output file" + ) + data_group.add_argument( + "--num-samples", + type=int, + default=None, + help="The number of samples to regenerate, if not provided, all samples will be regenerated", + ) + data_group.add_argument( + "--resume", + action="store_true", + help="Resume from existing output file, skip already processed samples", + ) + + # sglang server + server_group = parser.add_argument_group("sglang server") + server_group.add_argument( + "--server-address", + type=str, + nargs="+", + help="Server address and port for sglang model server", + ) + return parser.parse_args() + + +def get_random_reasoning_effort() -> str: + """Get a random reasoning effort level for the model with weighted probabilities.""" + # usage example: https://huggingface.co/openai/gpt-oss-20b/discussions/28 + # Reasoning effort levels with weights: LOW(4), MEDIUM(4), HIGH(2) + reasoning_efforts = [ + "low", + "medium", + "high", + ] + weights = [4, 4, 2] + return random.choices(reasoning_efforts, weights=weights, k=1)[0] + + +def compute_context_length(conversations: List[Dict[str, Any]]) -> int: + """ + This is a rough estimate of the context length measured in untokenized + tokens. + """ + length = 0 + for message in conversations: + content = message.get("content") + if isinstance(content, str): + # {"role": "assistant", "content": "Hi, how can I help?"} + length += len(content.split()) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str): + length += len(text.split()) + return length + + +def build_query_kwargs(args, messages, max_tokens=None): + effective_max_tokens = max_tokens if max_tokens is not None else args.max_tokens + + query_kwargs = dict( + model=args.model, + messages=messages, + max_tokens=effective_max_tokens, + temperature=args.temperature, + stream=False, + ) + if args.top_p is not None: + query_kwargs["top_p"] = args.top_p + if args.repetition_penalty is not None: + query_kwargs["presence_penalty"] = args.repetition_penalty + extra_body = {} + if args.top_k is not None: + extra_body["top_k"] = args.top_k + if extra_body: + query_kwargs["extra_body"] = extra_body + if args.is_gpt_oss: + query_kwargs["reasoning_effort"] = get_random_reasoning_effort() + return query_kwargs + + +def call_sglang( + args, + server_address: str, + data: List[Dict[str, Any]], + max_tokens=None, +) -> str: + """Send a batch of prompts to sglang /v1/completions.""" + # client = OpenAI(base_url=f"http://{server_address}/v1", api_key="None") + client = OpenAI(base_url=f"http://{server_address}/v1", api_key="eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJpc3MiOiAiemlsaW4ta2V5LTEyMyIsICJpYXQiOiAxNzU3NTU0NzUxLCAiZXhwIjogMTc4OTA5MDc1MX0.2iTg0MSLlD92H-28LZ_PCQG7RQNeF5X1O6IYHp8_u-c", + timeout=1200) + + messages = data["conversations"] + regenerated_messages = [] + + # ignore data which starts with an assistant message + if messages[0]["role"] == "assistant": + data["status"] = "error" + data["error"] = "Data starts with an assistant message" + return data + + for message in messages: + if message["role"] == "system": + regenerated_messages.append(message) + elif message["role"] == "assistant": + continue + elif message["role"] == "user": + regenerated_messages.append(message) + + query_kwargs = build_query_kwargs(args, regenerated_messages, max_tokens) + + try: + resp = client.chat.completions.create(**query_kwargs) + except Exception as e: + data["status"] = "error" + data["error"] = str(e) + return data + response_text = resp.choices[0].message.content + resp_msg = { + "role": "assistant", + "content": response_text, + } + if args.is_reasoning_model: + resp_msg["thinking"] = resp.choices[0].message.reasoning_content + regenerated_messages.append(resp_msg) + else: + data["status"] = "error" + data["error"] = f"Invalid message role: {message['role']}" + return data + data["conversations"] = regenerated_messages + data["status"] = "success" + return data + + +def main(): + # Parse command line arguments + args = parse_arguments() + + # Validate parameters + if not (0.0 <= args.temperature <= 1.0): + raise ValueError("Temperature must be between 0.0 and 1.0") + + if args.max_tokens <= 0: + raise ValueError("Max tokens must be greater than 0") + + print(f"Configuration:") + print(f" Model path: {args.model}") + print(f" Max tokens: {args.max_tokens}") + print(f" Concurrency: {args.concurrency}") + print(f" Temperature: {args.temperature}") + print(f" API URL: {args.server_address}") + print(f" Input file: {args.input_file_path}") + print(f" Output file: {args.output_file_path}") + print(f" Resume mode: {args.resume}") + print("-" * 50) + total_lines = sum(1 for _ in open(args.input_file_path)) + + skip_lines = 0 + error_file_path = args.output_file_path.replace(".jsonl", "_error.jsonl") + + if args.resume and os.path.exists(args.output_file_path): + existing_success = sum(1 for _ in open(args.output_file_path)) + existing_error = 0 + if os.path.exists(error_file_path): + existing_error = sum(1 for _ in open(error_file_path)) + skip_lines = existing_success + existing_error + print(f"Resume mode enabled:") + print(f" Found {existing_success} successful samples in output file") + print(f" Found {existing_error} error samples in error file") + print(f" Skipping first {skip_lines} input samples") + print("-" * 50) + + if skip_lines >= total_lines: + print(f"All {total_lines} samples already processed. Nothing to do.") + return + + # test all server addresses + valid_server_addresses = [] + for server_address in args.server_address: + dummy_data = dict( + conversations=[{"role": "user", "content": "Hello, how are you?"}] + ) + result = call_sglang( + args, + server_address, + dummy_data, + max_tokens=1, + ) + if result is not None: + valid_server_addresses.append(server_address) + else: + print(f"Server {server_address} is not available") + + if len(valid_server_addresses) == 0: + raise ValueError("No server address is available") + print( + f"Using {len(valid_server_addresses)} server addresses: {valid_server_addresses}" + ) + print("-" * 50) + + # Determine file open mode based on resume flag + file_mode = "a" if (args.resume and skip_lines > 0) else "w" + print( + f"Regenerating dataset and saving the output to {args.output_file_path} and error log to {error_file_path}" + ) + print( + f"File open mode: {file_mode} ({'append' if file_mode == 'a' else 'overwrite'})" + ) + print("-" * 50) + context_token_sum = 0 + context_token_min = None + context_token_max = 0 + success_samples = 0 + error_samples = 0 + + # Create progress bar + with ( + open(args.input_file_path, "r") as input_file, + open(args.output_file_path, file_mode) as output_file_handle, + open(error_file_path, file_mode) as error_file_handle, + ): + executor = ThreadPoolExecutor( + max_workers=args.concurrency * len(valid_server_addresses) + ) + waiting_queue = { + server_address: [] for server_address in valid_server_addresses + } + pbar = tqdm(total=total_lines, desc="Processing", initial=skip_lines) + start_server_index = 0 + + if skip_lines > 0: + print(f"Skipping {skip_lines} already processed samples...") + for _ in range(skip_lines): + next(input_file, None) + print(f"Resuming from sample {skip_lines + 1}") + + for line in input_file: + if ( + args.num_samples is not None + and success_samples + error_samples >= args.num_samples + ): + break + + data = json.loads(line.strip()) + + # find server address with the least waiting requests + server_address = valid_server_addresses[start_server_index] + start_server_index = (start_server_index + 1) % len(valid_server_addresses) + + # submit prompt to sglang + while len(waiting_queue[server_address]) >= args.concurrency: + finished_on_request = False + # check if any future is done, if so, write the result to the output file + for req_future in waiting_queue[server_address]: + if req_future.done(): + regen_data = req_future.result() + + if regen_data["status"] == "error": + error_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + error_samples += 1 + else: + ctx_len = compute_context_length( + regen_data.get("conversations", []) + ) + context_token_sum += ctx_len + if context_token_min is None: + context_token_min = ctx_len + else: + context_token_min = min(context_token_min, ctx_len) + context_token_max = max(context_token_max, ctx_len) + + output_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + success_samples += 1 + waiting_queue[server_address].remove(req_future) + finished_on_request = True + + if finished_on_request: + break + + req_future = executor.submit( + call_sglang, + args, + server_address, + data, + ) + waiting_queue[server_address].append(req_future) + pbar.update(1) + + # deal with all the remaining requests + for server_address, waiting_queue_items in waiting_queue.items(): + for req_future in waiting_queue_items: + regen_data = req_future.result() + if regen_data["status"] == "error": + error_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + error_samples += 1 + else: + ctx_len = compute_context_length( + regen_data.get("conversations", []) + ) + context_token_sum += ctx_len + if context_token_min is None: + context_token_min = ctx_len + else: + context_token_min = min(context_token_min, ctx_len) + context_token_max = max(context_token_max, ctx_len) + + output_file_handle.write( + json.dumps(regen_data, ensure_ascii=False) + "\n" + ) + success_samples += 1 + + print(f"\nProcessing completed!") + if success_samples > 0: + avg_len = context_token_sum / success_samples + print("Context length statistics (token count over conversations):") + print(f"Number of successful examples: {success_samples}") + print(f"Shortest context length: {context_token_min}") + print(f"Longest context length: {context_token_max}") + print(f"Average context length: {avg_len:.2f}") + else: + print("No successful examples to compute context length statistics.") + + total_processed = success_samples + error_samples + if skip_lines > 0: + print(f"\nResume processing completed!") + print(f" Previously processed: {skip_lines}") + print( + f" Newly processed: {total_processed} ({success_samples} success, {error_samples} failed)" + ) + print(f" Total: {skip_lines + total_processed}") + else: + print( + f"\nProcessing completed! {success_samples} samples regenerated, {error_samples} samples failed." + ) + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/run_eval_dflash_lora.sh b/progress/SpecForge/scripts/run_eval_dflash_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..433a4ebb2d7a32d3bd5d50344fd7efaa47acd54c --- /dev/null +++ b/progress/SpecForge/scripts/run_eval_dflash_lora.sh @@ -0,0 +1,28 @@ +#!/bin/bash +ROOT_DIR=/workspace/hanrui/junquan/SpecForge +export PYTHONPATH=$ROOT_DIR:$PYTHONPATH +export PATH=/workspace/hanrui/specforge/bin:$PATH +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=16 +export PYTHONUNBUFFERED=1 +export TRANSFORMERS_OFFLINE=1 + +NUM_GPUS=${1:-8} + +/workspace/hanrui/specforge/bin/python3 -m torch.distributed.run \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/eval_dflash_lora.py \ + --model-path /workspace/models/Qwen3-8B \ + --ckpt-dir $ROOT_DIR/outputs/qwen3-8b-dflash-lora/epoch_2_step_218500 \ + --data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \ + --lora-config $ROOT_DIR/configs/qwen3-8b-dflash-lora.json \ + --block-size 16 \ + --max-length 2048 \ + --batch-size 1 \ + --attention-backend flex_attention \ + --lm-head-chunk-size 256 \ + --chat-template qwen \ + --log-interval 50 \ + --cache-dir $ROOT_DIR/cache \ + --dist-timeout 120 diff --git a/progress/SpecForge/scripts/run_train_dflash_lora.sh b/progress/SpecForge/scripts/run_train_dflash_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..c3d236d5783b1eacd8ce592722df3b20928a0f54 --- /dev/null +++ b/progress/SpecForge/scripts/run_train_dflash_lora.sh @@ -0,0 +1,32 @@ +#!/bin/bash +ROOT_DIR=/workspace/hanrui/junquan/SpecForge +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=16 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +export PATH=/workspace/hanrui/specforge/bin:$PATH +export PYTHONPATH=$ROOT_DIR:$PYTHONPATH + +NUM_GPUS=${1:-8} + +/workspace/hanrui/specforge/bin/python3 -m torch.distributed.run \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash_lora.py \ + --model-path /workspace/Qwen3-8B \ + --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-dflash-lora \ + --lora-config $ROOT_DIR/configs/qwen3-8b-dflash-lora.json \ + --block-size 16 \ + --max-length 2048 \ + --batch-size 1 \ + --num-epochs 3 \ + --learning-rate 2e-4 \ + --accumulation-steps 8 \ + --loss-decay-gamma 7 \ + --attention-backend flex_attention \ + --lm-head-chunk-size 256 \ + --gradient-checkpointing \ + --chat-template qwen \ + --log-interval 50 \ + --save-interval 500 \ + --cache-dir $ROOT_DIR/cache diff --git a/progress/SpecForge/scripts/train_dflash.py b/progress/SpecForge/scripts/train_dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf4f87139c7da50b174c2d178b8b61f1e5c0f35 --- /dev/null +++ b/progress/SpecForge/scripts/train_dflash.py @@ -0,0 +1,720 @@ +#!/usr/bin/env python3 +# coding=utf-8 +"""DFlash Training Script.""" + +import argparse +import logging +import math +import os +import shutil +import time +import warnings +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, AutoTokenizer + +from datasets import load_dataset +from specforge.args import SGLangBackendArgs, TrackerArgs +from specforge.core.dflash import OnlineDFlashModel +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import destroy_distributed, get_dp_group, get_tp_group, init_distributed +from specforge.modeling.draft.dflash import DFlashDraftModel +from specforge.modeling.target.dflash_target_model import ( + DFlashTargetModel, + get_dflash_target_model, +) +from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead +from specforge.optimizer import BF16Optimizer +from specforge.tracker import create_tracker +from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank + + +# ──────────────────────────────────────────────────────────────── +# Memory profiling utilities +# ──────────────────────────────────────────────────────────────── + +def _mb(bytes_val: int) -> float: + return bytes_val / 1024 ** 2 + + +def log_cuda_memory(tag: str, rank_only: int = 0) -> None: + """Print current / peak CUDA memory at a labelled checkpoint (rank 0 only).""" + if not torch.cuda.is_available(): + return + if dist.is_available() and dist.is_initialized() and dist.get_rank() != rank_only: + return + allocated = _mb(torch.cuda.memory_allocated()) + reserved = _mb(torch.cuda.memory_reserved()) + peak_alloc = _mb(torch.cuda.max_memory_allocated()) + peak_res = _mb(torch.cuda.max_memory_reserved()) + logging.getLogger(__name__).info( + f"[VRAM | {tag}] " + f"allocated={allocated:.1f} MB reserved={reserved:.1f} MB " + f"peak_alloc={peak_alloc:.1f} MB peak_res={peak_res:.1f} MB" + ) + + +def log_model_memory(name: str, model: torch.nn.Module, rank_only: int = 0) -> None: + """Print parameter + gradient memory for a given model (rank 0 only).""" + if dist.is_available() and dist.is_initialized() and dist.get_rank() != rank_only: + return + param_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) + grad_bytes = sum( + p.grad.numel() * p.grad.element_size() + for p in model.parameters() + if p.grad is not None + ) + logging.getLogger(__name__).info( + f"[MODEL MEM | {name}] " + f"params={_mb(param_bytes):.1f} MB " + f"grads={_mb(grad_bytes):.1f} MB " + f"total={_mb(param_bytes + grad_bytes):.1f} MB" + ) + + +def log_optimizer_memory(name: str, optimizer, rank_only: int = 0) -> None: + """Estimate optimizer state memory (rank 0 only).""" + if dist.is_available() and dist.is_initialized() and dist.get_rank() != rank_only: + return + state_bytes = 0 + for state in optimizer.optimizer.state.values(): + for v in state.values(): + if isinstance(v, torch.Tensor): + state_bytes += v.numel() * v.element_size() + logging.getLogger(__name__).info( + f"[OPT MEM | {name}] optimizer_states={_mb(state_bytes):.1f} MB" + ) + + +def log_tensor_memory(name: str, tensor: torch.Tensor, rank_only: int = 0) -> None: + """Print memory of a single tensor (rank 0 only).""" + if dist.is_available() and dist.is_initialized() and dist.get_rank() != rank_only: + return + mb = _mb(tensor.numel() * tensor.element_size()) + logging.getLogger(__name__).info( + f"[TENSOR | {name}] shape={tuple(tensor.shape)} dtype={tensor.dtype} size={mb:.1f} MB" + ) + + +def reset_peak_memory() -> None: + """Reset CUDA peak memory stats.""" + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + +# ──────────────────────────────────────────────────────────────── + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train DFlash Draft Model") + + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--target-model-backend", + type=str, + default="hf", + choices=["sglang", "hf"], + help="Backend for target model: 'sglang' (service) or 'hf' (local)", + ) + model_group.add_argument("--draft-config-path", type=str, default=None) + model_group.add_argument("--block-size", type=int, default=16) + model_group.add_argument("--num-draft-layers", type=int, default=1) + model_group.add_argument( + "--mask-token-id", + type=int, + default=None, + help="MASK token ID. If not provided, auto-detect from tokenizer.", + ) + model_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + choices=["eager", "sdpa", "flex_attention"], + help="Attention backend for draft model.", + ) + model_group.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code" + ) + model_group.add_argument( + "--random-anchor", + action="store_true", + help="Enable random anchor sampling for block construction (paper Sec 4.2).", + ) + model_group.add_argument( + "--num-anchors", + type=int, + default=512, + help="Number of anchor positions per sequence when --random-anchor is set.", + ) + model_group.add_argument( + "--loss-decay-gamma", + type=float, + default=None, + help="Gamma for exponential loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables.", + ) + + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="qwen") + dataset_group.add_argument("--is-preformatted", action="store_true") + dataset_group.add_argument("--dataloader-num-workers", type=int, default=8) + dataset_group.add_argument( + "--build-dataset-num-proc", + type=int, + default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)), + ) + + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=6) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=6e-4) + training_group.add_argument("--max-length", type=int, default=3072) + training_group.add_argument("--warmup-ratio", type=float, default=0.04) + training_group.add_argument("--max-grad-norm", type=float, default=1.0) + training_group.add_argument("--accumulation-steps", type=int, default=1) + training_group.add_argument( + "--optimizer-type", + type=str, + default="adamw", + choices=["adamw", "adamw_8bit", "apollo"], + help="Optimizer type (default: adamw)", + ) + training_group.add_argument( + "--optimizer-config", + type=str, + default=None, + help="Path to optimizer config JSON file (required for apollo)", + ) + training_group.add_argument( + "--no-fp32-params", + action="store_true", + help="Disable FP32 master copy of parameters to save memory", + ) + training_group.add_argument( + "--gradient-checkpointing", + action="store_true", + help="Enable gradient checkpointing to save memory (trades compute for memory)", + ) + training_group.add_argument("--seed", type=int, default=42) + training_group.add_argument("--resume", action="store_true") + training_group.add_argument( + "--ckpt-dir", + type=str, + default=None, + help="Directory of the checkpoint to resume training from", + ) + + output_group = parser.add_argument_group("output") + output_group.add_argument("--output-dir", type=str, required=True) + output_group.add_argument("--cache-dir", type=str, default="./cache") + output_group.add_argument("--log-interval", type=int, default=50) + output_group.add_argument("--eval-interval", type=int, default=1000) + output_group.add_argument("--save-interval", type=int, default=1000) + + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--tp-size", + type=int, + default=1, + help="The size of the tensor parallel for the target model", + ) + optimization_group.add_argument( + "--lm-head-chunk-size", + type=int, + default=0, + help="Chunk size for lm_head + CE loss computation. " + "When > 0, processes sequence in chunks to avoid materializing " + "full [bsz, seq_len, vocab_size] logits tensor. " + "Recommended: 256-1024 for large vocab models. 0 disables chunking.", + ) + + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + dist_group = parser.add_argument_group("distributed") + dist_group.add_argument("--dist-timeout", type=int, default=30) + + # SGLang specific args + sglang_group = parser.add_argument_group("sglang backend") + SGLangBackendArgs.add_args(sglang_group) + + return parser.parse_args() + + +def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: + """Build target model (backend wrapper) and draft model.""" + print_on_rank0( + f"Loading target model from {args.target_model_path} using {args.target_model_backend} backend" + ) + + # 1. Build Target Model Wrapper + target_model_kwargs = {} + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + + target_model = get_dflash_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device="cuda" if args.target_model_backend == "hf" else None, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) + + # 2. Build Draft Model + if args.draft_config_path: + draft_config = AutoConfig.from_pretrained(args.draft_config_path) + print_on_rank0(f"Loaded draft config from {args.draft_config_path}") + else: + target_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config.num_hidden_layers = args.num_draft_layers + draft_config.block_size = args.block_size + draft_config.num_target_layers = target_config.num_hidden_layers + print_on_rank0("Auto-generated draft config from target model") + + if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None: + draft_config.dflash_config = {} + + draft_config._attn_implementation = args.attention_backend + print_on_rank0(f"Using attention backend: {args.attention_backend}") + + draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16) + + target_model.set_capture_layers(draft_model.target_layer_ids) + + print_on_rank0( + f"Draft config: block_size={draft_config.block_size}, " + f"num_hidden_layers={draft_config.num_hidden_layers}, " + f"num_target_layers={draft_config.num_target_layers}" + ) + print_on_rank0( + f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}" + ) + + # ── Memory checkpoint: after model loading ── + log_cuda_memory("after build_models") + if hasattr(target_model, "model"): + log_model_memory("target_model", target_model.model) + log_model_memory("draft_model", draft_model) + + return target_model, draft_model + + +def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: + """Build train and eval dataloaders.""" + import hashlib + + cache_params_string = ( + f"{args.train_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + + min_loss_tokens = 2 * args.block_size + original_size = len(train_eagle3_dataset) + train_eagle3_dataset = train_eagle3_dataset.filter( + lambda x: x["loss_mask"].sum() >= min_loss_tokens + ) + print_on_rank0( + f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples" + ) + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=get_dp_group(), + ) + + eval_dataloader = None + if args.eval_data_path: + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + eval_eagle3_dataset = build_eagle3_dataset( + dataset=eval_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=get_dp_group(), + ) + + return train_dataloader, eval_dataloader + + +def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer): + """Save checkpoint.""" + save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(save_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(dflash_model, StateDictType.FULL_STATE_DICT): + state_dict = dflash_model.state_dict() + draft_state_dict = { + k.replace("draft_model.", ""): v + for k, v in state_dict.items() + if "draft_model." in k + } + + if dist.get_rank() == 0: + torch.save( + { + "epoch": epoch, + "global_step": step, + "args": args, + **optimizer.state_dict(), + }, + os.path.join(save_dir, "training_state.pt"), + ) + + draft_model.save_pretrained(save_dir, state_dict=draft_state_dict) + + modeling_src = os.path.join( + os.path.dirname(__file__), + "..", + "specforge", + "modeling", + "draft", + "dflash.py", + ) + modeling_dst = os.path.join(save_dir, "dflash.py") + if os.path.exists(modeling_src): + shutil.copy(modeling_src, modeling_dst) + + print_on_rank0(f"Saved checkpoint to {save_dir}") + + dist.barrier() + + +def record_metrics( + args, + loss: float, + accuracy: float, + global_step: int, + tracker, + optimizer, + train_dataloader=None, + mode: str = "train", +) -> None: + logdict = {} + + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + + logdict[f"{mode}/loss"] = loss + logdict[f"{mode}/accuracy"] = accuracy + + print_on_rank0( + f"{mode.capitalize()} - Step {global_step} [{global_step}/{args.num_epochs * len(train_dataloader) // args.accumulation_steps}?], Loss: {loss:.4f}, Acc: {accuracy:.4f}" + ) + + tracker.log(logdict, step=global_step) + + +def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor: + """Shard batch data across TP ranks so each rank processes a unique portion.""" + tp_size = dist.get_world_size(get_tp_group()) + tp_rank = dist.get_rank(get_tp_group()) + return tensor.chunk(tp_size, dim=0)[tp_rank] + + +def main(): + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logging.getLogger().setLevel(logging.INFO) + warnings.filterwarnings( + "ignore", + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed", + ) + + args = parse_args() + set_seed(args.seed) + + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_with_rank("Initialized distributed") + + target_model, draft_model = build_models(args) + + draft_model_last_checkpoint = None + # ── Memory checkpoint 1: right after models are on GPU ── + log_cuda_memory("checkpoint-1: after build_models") + if args.ckpt_dir is not None: + if os.path.isdir(args.ckpt_dir): + draft_model_last_checkpoint = args.ckpt_dir + print_on_rank0(f"Using checkpoint: {draft_model_last_checkpoint}") + else: + raise ValueError( + f"Provided ckpt dir {args.ckpt_dir} is not a valid directory." + ) + + if args.resume and os.path.isdir(args.output_dir): + draft_model_last_checkpoint = get_last_checkpoint( + args.output_dir, prefix=r"epoch_\d+_step" + ) + print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}") + + resume_state = None + if draft_model_last_checkpoint: + loaded_model = DFlashDraftModel.from_pretrained( + draft_model_last_checkpoint, torch_dtype=torch.bfloat16 + ) + draft_model.load_state_dict(loaded_model.state_dict()) + del loaded_model + print_on_rank0("Loaded draft model weights from checkpoint") + + training_state_path = os.path.join( + draft_model_last_checkpoint, "training_state.pt" + ) + if os.path.exists(training_state_path): + resume_state = torch.load( + training_state_path, map_location="cpu", weights_only=False + ) + print_on_rank0( + f"Will resume from epoch {resume_state['epoch']}, " + f"step {resume_state['global_step']}" + ) + + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + + if args.mask_token_id is not None: + mask_token_id = args.mask_token_id + elif tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + else: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = tokenizer.mask_token_id + print_on_rank0(f"Using mask_token_id: {mask_token_id}") + + draft_model.mask_token_id = mask_token_id + draft_model.config.dflash_config["mask_token_id"] = mask_token_id + draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids + if args.gradient_checkpointing: + draft_model.gradient_checkpointing_enable() + print_on_rank0("Gradient checkpointing enabled for draft model") + print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}") + + train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + + steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) + total_steps = args.num_epochs * steps_per_epoch + print_on_rank0(f"Total training steps: {total_steps}") + + print_on_rank0("Loading target embeddings and head...") + target_components = TargetEmbeddingsAndHead.from_pretrained( + args.target_model_path, + embed_key="model.embed_tokens.weight", # Adjust if Qwen/Llama differs + lm_head_key="lm_head.weight", + device="cuda", + trust_remote_code=args.trust_remote_code, + ) + + # ── Memory checkpoint 2: after loading embed + lm_head ── + log_cuda_memory("checkpoint-2: after TargetEmbeddingsAndHead") + log_model_memory("embed_tokens", target_components.embed_tokens) + log_model_memory("lm_head", target_components.lm_head) + + dflash_model = OnlineDFlashModel( + draft_model=draft_model, + target_lm_head=target_components.lm_head, + target_embed_tokens=target_components.embed_tokens, + block_size=draft_model.block_size, + mask_token_id=mask_token_id, + attention_backend=args.attention_backend, + random_anchor=args.random_anchor, + num_anchors=args.num_anchors, + loss_decay_gamma=args.loss_decay_gamma, + lm_head_chunk_size=args.lm_head_chunk_size, + ) + + dflash_model = FSDP( + dflash_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + ) + print_with_rank("Initialized FSDP") + + # ── Memory checkpoint 3: after FSDP wrapping ── + log_cuda_memory("checkpoint-3: after FSDP wrap") + + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=total_steps, + use_fp32_params=not args.no_fp32_params, + optimizer_type=args.optimizer_type, + optimizer_config=args.optimizer_config, + ) + + # ── Memory checkpoint 4: after optimizer init ── + log_cuda_memory("checkpoint-4: after optimizer init") + log_optimizer_memory("BF16Optimizer", optimizer) + + start_epoch = 0 + global_step = 0 + if resume_state is not None: + optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"]) + start_epoch = resume_state["epoch"] + global_step = resume_state["global_step"] + del resume_state + print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}") + + skip_steps = global_step - start_epoch * len(train_dataloader) + + print_on_rank0(f"Initializing tracker (report_to={args.report_to})...") + tracker = create_tracker(args, args.output_dir) + print_on_rank0("Tracker initialized successfully.") + + last_time = time.time() + print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}") + + for epoch in range(start_epoch, args.num_epochs): + train_dataloader.sampler.set_epoch(epoch) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm( + train_dataloader, desc=f"Training Epoch {epoch}", leave=True + ) + else: + progress_bar = train_dataloader + + for step_in_epoch, data in enumerate(progress_bar): + if epoch == start_epoch and step_in_epoch < skip_steps: + continue + global_step += 1 + + # ── Memory checkpoint 5: start of step (only first step) ── + _is_first_step = (global_step == (start_epoch * len(train_dataloader) + skip_steps + 1)) + if _is_first_step: + reset_peak_memory() + log_cuda_memory("step-start (first step)") + + input_ids = data["input_ids"].cuda() + attention_mask = data["attention_mask"].cuda() + loss_mask = data["loss_mask"].cuda() + + if _is_first_step: + log_tensor_memory("input_ids", input_ids) + log_tensor_memory("attention_mask", attention_mask) + log_tensor_memory("loss_mask", loss_mask) + log_cuda_memory("after data-to-GPU") + + target_output = target_model.generate_dflash_data( + input_ids, attention_mask, loss_mask + ) + hidden_states = target_output.hidden_states.cuda().clone() # Ensure on GPU + + if _is_first_step: + log_tensor_memory("hidden_states", hidden_states) + log_cuda_memory("after target_model.generate_dflash_data") + + loss, accuracy = dflash_model( + input_ids=input_ids, + attention_mask=attention_mask, + hidden_states=hidden_states, + loss_mask=loss_mask, + ) + + if _is_first_step: + log_cuda_memory("after dflash_model forward") + + (loss / args.accumulation_steps).backward() + + if _is_first_step: + log_cuda_memory("after backward") + log_model_memory("draft_model (with grads)", draft_model) + + if global_step % args.accumulation_steps == 0: + optimizer.step() + + if _is_first_step: + log_cuda_memory("after optimizer.step") + log_optimizer_memory("BF16Optimizer (after first step)", optimizer) + + if global_step % args.log_interval == 0: + loss_log = loss.clone() + acc_log = accuracy.clone() + dist.all_reduce(loss_log) + dist.all_reduce(acc_log) + loss_log = loss_log / dist.get_world_size() + acc_log = acc_log / dist.get_world_size() + + record_metrics( + args, + loss_log.item(), + acc_log.item(), + global_step, + tracker, + optimizer, + train_dataloader, + mode="train", + ) + + if dist.get_rank() == 0: + elapsed = time.time() - last_time + last_time = time.time() + progress_bar.set_postfix( + { + "loss": f"{loss.item():.4f}", + "acc": f"{accuracy.item():.4f}", + "iter_time": f"{elapsed:.2f}s", + } + ) + + if global_step % args.save_interval == 0: + save_checkpoint( + args, epoch, global_step, dflash_model, draft_model, optimizer + ) + + save_checkpoint( + args, args.num_epochs, global_step, dflash_model, draft_model, optimizer + ) + + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/train_dflash_lora.py b/progress/SpecForge/scripts/train_dflash_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..09bab9b6a2f0768c0ad320cfae336e7d5ef15883 --- /dev/null +++ b/progress/SpecForge/scripts/train_dflash_lora.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +# coding=utf-8 +"""DFlash LoRA Training Script. + +Trains Qwen3-8B with LoRA adapters to learn 1-step parallel block generation +(dLLM capability). No separate target model is needed for hidden state extraction — +the LoRA model uses its own representations. + +Key differences from train_dflash.py: + - No DFlashTargetModel (no hidden state extraction) + - No TargetEmbeddingsAndHead (model uses its own embed/lm_head) + - DFlashLoRADraftModel: Qwen3-8B + PEFT LoRA + - OnlineDFlashLoRAModel: full-sequence DFlash attention mask + - Only LoRA parameters are trained; base model is frozen + - Saves LoRA adapter weights only +""" + +import argparse +import json +import logging +import math +from contextlib import nullcontext +import os +import time +import warnings +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoTokenizer + +from datasets import load_dataset +from specforge.args import TrackerArgs +from specforge.core.dflash_lora import OnlineDFlashLoRAModel +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import destroy_distributed, get_dp_group, init_distributed +from specforge.modeling.draft.dflash_lora import DFlashLoRADraftModel +from specforge.optimizer import BF16Optimizer +from specforge.tracker import create_tracker +from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train DFlash LoRA (Qwen3-8B + LoRA)") + + model_group = parser.add_argument_group("model") + model_group.add_argument("--model-path", type=str, required=True, + help="Path to Qwen3-8B (or any CausalLM) base model") + model_group.add_argument("--block-size", type=int, default=16) + model_group.add_argument("--mask-token-id", type=int, default=None, + help="MASK token ID. Auto-detected from tokenizer if not set.") + model_group.add_argument("--context-len", type=int, default=0, + help="Fixed context length before blocks. 0 = treat whole seq as blocks.") + model_group.add_argument("--trust-remote-code", action="store_true") + model_group.add_argument("--attn-implementation", type=str, default="sdpa", + choices=["sdpa", "eager"], + help="Attention backend for additive mask path. " + "Ignored when --attention-backend=flex_attention.") + model_group.add_argument("--attention-backend", type=str, default="flex_attention", + choices=["flex_attention", "additive"], + help="flex_attention: use BlockMask (zero extra memory). " + "additive: use 4D additive mask with SDPA/eager.") + model_group.add_argument("--lm-head-chunk-size", type=int, default=0, + help="Chunk size for chunked cross-entropy loss. " + "0 = full logits (default). 256-512 recommended to reduce VRAM.") + model_group.add_argument("--random-anchor", action="store_true", + help="Randomly sample anchor positions each step (like non-LoRA dflash).") + model_group.add_argument("--num-anchors", type=int, default=512, + help="Max number of random anchor positions per sample (default: 512).") + + lora_group = parser.add_argument_group("lora") + lora_group.add_argument("--lora-rank", type=int, default=16) + lora_group.add_argument("--lora-alpha", type=int, default=32) + lora_group.add_argument("--lora-dropout", type=float, default=0.05) + lora_group.add_argument("--lora-target-modules", type=str, nargs="+", + default=["q_proj", "k_proj", "v_proj", "o_proj"], + help="Which modules to apply LoRA to") + lora_group.add_argument("--lora-config", type=str, default=None, + help="Path to JSON file with LoRA config (overrides individual args)") + + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="qwen") + dataset_group.add_argument("--is-preformatted", action="store_true") + dataset_group.add_argument("--dataloader-num-workers", type=int, default=8) + dataset_group.add_argument("--build-dataset-num-proc", type=int, + default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8))) + + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=3) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=2e-4) + training_group.add_argument("--max-length", type=int, default=2048) + training_group.add_argument("--warmup-ratio", type=float, default=0.04) + training_group.add_argument("--max-grad-norm", type=float, default=1.0) + training_group.add_argument("--accumulation-steps", type=int, default=1) + training_group.add_argument("--loss-decay-gamma", type=float, default=None) + training_group.add_argument("--optimizer-type", type=str, default="adamw", + choices=["adamw", "adamw_8bit"]) + training_group.add_argument("--no-fp32-params", action="store_true") + training_group.add_argument("--gradient-checkpointing", action="store_true") + training_group.add_argument("--seed", type=int, default=42) + training_group.add_argument("--resume", action="store_true") + training_group.add_argument("--ckpt-dir", type=str, default=None) + + output_group = parser.add_argument_group("output") + output_group.add_argument("--output-dir", type=str, required=True) + output_group.add_argument("--cache-dir", type=str, default="./cache") + output_group.add_argument("--log-interval", type=int, default=50) + output_group.add_argument("--eval-interval", type=int, default=1000) + output_group.add_argument("--save-interval", type=int, default=1000) + + dist_group = parser.add_argument_group("distributed") + dist_group.add_argument("--dist-timeout", type=int, default=30) + + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + return parser.parse_args() + + +def build_model(args) -> Tuple[DFlashLoRADraftModel, OnlineDFlashLoRAModel]: + """Load Qwen3-8B, inject LoRA, wrap in OnlineDFlashLoRAModel.""" + print_on_rank0(f"Loading base model from {args.model_path}") + + # Load LoRA config from JSON if provided + lora_rank = args.lora_rank + lora_alpha = args.lora_alpha + lora_dropout = args.lora_dropout + lora_target_modules = args.lora_target_modules + + if args.lora_config is not None: + with open(args.lora_config) as f: + lora_cfg = json.load(f) + lora_rank = lora_cfg.get("lora_rank", lora_rank) + lora_alpha = lora_cfg.get("lora_alpha", lora_alpha) + lora_dropout = lora_cfg.get("lora_dropout", lora_dropout) + lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules) + print_on_rank0(f"Loaded LoRA config from {args.lora_config}") + + # Resolve attn_implementation: flex_attention backend uses HF flex_attention impl + if args.attention_backend == "flex_attention": + attn_impl = "flex_attention" + else: + attn_impl = args.attn_implementation # sdpa or eager + + draft_model = DFlashLoRADraftModel.from_pretrained( + pretrained_model_name_or_path=args.model_path, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + lora_target_modules=lora_target_modules, + block_size=args.block_size, + mask_token_id=args.mask_token_id or 151669, # placeholder, updated after tokenizer load + torch_dtype=torch.bfloat16, + device_map="cuda", + trust_remote_code=args.trust_remote_code, + attn_implementation=attn_impl, + ) + + online_model = OnlineDFlashLoRAModel( + draft_model=draft_model, + block_size=args.block_size, + mask_token_id=args.mask_token_id or 151669, + loss_decay_gamma=args.loss_decay_gamma, + attention_backend=args.attention_backend, + lm_head_chunk_size=args.lm_head_chunk_size, + random_anchor=args.random_anchor, + num_anchors=args.num_anchors, + ) + + trainable = sum(p.numel() for p in draft_model.parameters() if p.requires_grad) + total = sum(p.numel() for p in draft_model.parameters()) + print_on_rank0(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)") + + return draft_model, online_model + + +def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: + """Build train and eval dataloaders (same as train_dflash.py).""" + import hashlib + + cache_params_string = ( + f"{args.train_data_path}-{args.max_length}-{args.chat_template}-{args.model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + rank = dist.get_rank() + + # Support both jsonl and parquet/directory formats + if os.path.isdir(args.train_data_path): + train_dataset = load_dataset(args.train_data_path, split="train") + else: + train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] + + dataset_kwargs = dict( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + + # Only rank 0 runs the expensive .map() preprocessing. + # Other ranks wait, then load directly from the cached result. + # This avoids N_GPU × num_proc concurrent workers causing OOM. + if rank == 0: + train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs) + dist.barrier() + if rank != 0: + train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs) + + min_loss_tokens = 2 * args.block_size + original_size = len(train_eagle3_dataset) + train_eagle3_dataset = train_eagle3_dataset.filter( + lambda x: x["loss_mask"].sum() >= min_loss_tokens + ) + print_on_rank0(f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples") + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=get_dp_group(), + ) + + eval_dataloader = None + if args.eval_data_path: + if os.path.isdir(args.eval_data_path): + eval_dataset = load_dataset(args.eval_data_path, split="train") + else: + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + eval_eagle3_dataset = build_eagle3_dataset( + dataset=eval_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=get_dp_group(), + ) + + return train_dataloader, eval_dataloader + + +def save_checkpoint(args, epoch, step, online_model, draft_model, optimizer): + """Save LoRA adapter weights + training state.""" + save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(save_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(online_model, StateDictType.FULL_STATE_DICT): + if dist.get_rank() == 0: + # Save LoRA adapter only + draft_model.save_pretrained(save_dir) + + torch.save( + { + "epoch": epoch, + "global_step": step, + "args": args, + **optimizer.state_dict(), + }, + os.path.join(save_dir, "training_state.pt"), + ) + print_on_rank0(f"Saved LoRA checkpoint to {save_dir}") + + dist.barrier() + + +def record_metrics(args, loss, accuracy, global_step, tracker, optimizer, + train_dataloader=None, mode="train"): + logdict = {} + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + logdict[f"{mode}/loss"] = loss + logdict[f"{mode}/accuracy"] = accuracy + print_on_rank0( + f"{mode.capitalize()} - Step {global_step}, Loss: {loss:.4f}, Acc: {accuracy:.4f}" + ) + tracker.log(logdict, step=global_step) + + +def main(): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + warnings.filterwarnings( + "ignore", + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed", + ) + + args = parse_args() + set_seed(args.seed) + + # tp_size=1: LoRA training doesn't use tensor parallelism + init_distributed(timeout=args.dist_timeout, tp_size=1) + print_with_rank("Initialized distributed") + + # Load tokenizer and resolve mask_token_id + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + if args.mask_token_id is not None: + mask_token_id = args.mask_token_id + elif tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + else: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = tokenizer.mask_token_id + print_on_rank0(f"Using mask_token_id: {mask_token_id}") + args.mask_token_id = mask_token_id + + draft_model, online_model = build_model(args) + + # Update mask_token_id in models after tokenizer resolution + draft_model.mask_token_id = mask_token_id + online_model.mask_token_id = mask_token_id + + if args.gradient_checkpointing: + draft_model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + print_on_rank0("Gradient checkpointing enabled") + + # Resume from checkpoint + resume_state = None + if args.ckpt_dir is not None: + if os.path.isdir(args.ckpt_dir): + print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}") + from peft import PeftModel + draft_model.model = PeftModel.from_pretrained( + draft_model.model.base_model.model, args.ckpt_dir + ) + else: + raise ValueError(f"ckpt_dir {args.ckpt_dir} is not a valid directory") + + if args.resume and os.path.isdir(args.output_dir): + last_ckpt = get_last_checkpoint(args.output_dir, prefix=r"epoch_\d+_step") + if last_ckpt: + print_on_rank0(f"Resuming from {last_ckpt}") + from peft import PeftModel + draft_model.model = PeftModel.from_pretrained( + draft_model.model.base_model.model, last_ckpt + ) + training_state_path = os.path.join(last_ckpt, "training_state.pt") + if os.path.exists(training_state_path): + resume_state = torch.load(training_state_path, map_location="cpu", weights_only=False) + print_on_rank0( + f"Will resume from epoch {resume_state['epoch']}, step {resume_state['global_step']}" + ) + + train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + + steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) + total_steps = args.num_epochs * steps_per_epoch + print_on_rank0(f"Total training steps: {total_steps}") + + # Wrap with FSDP (only LoRA params will have gradients) + online_model = FSDP( + online_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.NO_SHARD, + ) + print_with_rank("Initialized FSDP") + + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=total_steps, + use_fp32_params=not args.no_fp32_params, + optimizer_type=args.optimizer_type, + ) + + start_epoch = 0 + global_step = 0 + if resume_state is not None: + optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"]) + start_epoch = resume_state["epoch"] + global_step = resume_state["global_step"] + del resume_state + print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}") + + skip_steps = global_step - start_epoch * len(train_dataloader) + + tracker = create_tracker(args, args.output_dir) + last_time = time.time() + print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}") + + for epoch in range(start_epoch, args.num_epochs): + train_dataloader.sampler.set_epoch(epoch) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=True) + else: + progress_bar = train_dataloader + + for step_in_epoch, data in enumerate(progress_bar): + if epoch == start_epoch and step_in_epoch < skip_steps: + continue + global_step += 1 + + input_ids = data["input_ids"].cuda() + attention_mask = data["attention_mask"].cuda() + loss_mask = data["loss_mask"].cuda() + + # Skip gradient sync during accumulation steps (only sync at optimizer step) + is_accumulation_step = (global_step % args.accumulation_steps) != 0 + ctx = online_model.no_sync() if is_accumulation_step else nullcontext() + + with ctx: + loss, accuracy = online_model( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + context_len=args.context_len, + ) + (loss / args.accumulation_steps).backward() + + if global_step % args.accumulation_steps == 0: + optimizer.step() + + if global_step % args.log_interval == 0: + loss_val = loss.item() + acc_val = accuracy.item() + loss_t = torch.tensor(loss_val, device="cuda") + acc_t = torch.tensor(acc_val, device="cuda") + dist.all_reduce(loss_t) + dist.all_reduce(acc_t) + record_metrics(args, loss_t.item() / dist.get_world_size(), + acc_t.item() / dist.get_world_size(), global_step, + tracker, optimizer, train_dataloader, mode="train") + + if dist.get_rank() == 0: + elapsed = time.time() - last_time + last_time = time.time() + progress_bar.set_postfix({ + "loss": f"{loss.item():.4f}", + "acc": f"{accuracy.item():.4f}", + "iter_time": f"{elapsed:.2f}s", + }) + + if global_step % args.save_interval == 0: + save_checkpoint(args, epoch, global_step, online_model, draft_model, optimizer) + + save_checkpoint(args, args.num_epochs, global_step, online_model, draft_model, optimizer) + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/scripts/train_eagle3.py b/progress/SpecForge/scripts/train_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..e582e250b65b75a1531c6e0156923f6feb96f2e4 --- /dev/null +++ b/progress/SpecForge/scripts/train_eagle3.py @@ -0,0 +1,1002 @@ +import argparse +import hashlib +import math +import os +import time +from argparse import ArgumentParser, Namespace +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from accelerate.utils import set_seed +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoProcessor, AutoTokenizer + +from datasets import Dataset +from specforge import ( + AutoDraftModelConfig, + AutoEagle3DraftModel, + OnlineEagle3Model, + QwenVLOnlineEagle3Model, +) +from specforge.args import SGLangBackendArgs, TrackerArgs +from specforge.data import ( + build_eagle3_dataset, + build_offline_eagle3_dataset, + generate_vocab_mapping_file, + prepare_dp_dataloaders, +) +from specforge.distributed import ( + destroy_distributed, + get_dp_group, + get_draft_dp_group, + get_tp_group, + init_distributed, +) +from specforge.modeling.target import ( + Eagle3TargetModel, + TargetHead, + get_eagle3_target_model, +) +from specforge.optimizer import BF16Optimizer +from specforge.tracker import Tracker, create_tracker, get_tracker_class +from specforge.utils import ( + create_draft_config_from_target, + get_last_checkpoint, + print_args_with_dots, + print_on_rank0, + print_with_rank, + rank_0_priority, + safe_conversations_generator, +) + + +def parse_args() -> Tuple[ArgumentParser, Namespace]: + """ + This function is used to parse the arguments for the training script. + """ + parser = argparse.ArgumentParser(description="Train Eagle3 with online data") + + # add model-related arguments + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code" + ) + model_group.add_argument( + "--draft-model-config", + type=str, + required=False, + help="Draft model config path. If not provided, will auto-generate from target model.", + ) + model_group.add_argument( + "--embedding-key", + type=str, + default="model.embed_tokens.weight", + help="The key of the embedding weight to load from the target model", + ) + model_group.add_argument( + "--lm-head-key", + type=str, + default="lm_head.weight", + help="The key of the lm head weight to load from the target model, this is only required for offline training", + ) + model_group.add_argument( + "--is-vlm", action="store_true", help="Whether the target model is a VLM" + ) + model_group.add_argument( + "--target-model-backend", + type=str, + default="sglang", + choices=["sglang", "hf", "custom"], + help="The backend of the target model", + ) + + # dataset arguments + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) + dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="llama3") + dataset_group.add_argument( + "--is-preformatted", + action="store_true", + help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", + ) + dataset_group.add_argument( + "--train-only-last-turn", + action="store_true", + help="If set, only the last assistant turn in each conversation contributes to the loss. " + "Useful for thinking models where conversation history may lack thought processes.", + ) + dataset_group.add_argument("--build-dataset-num-proc", type=int, default=8) + dataset_group.add_argument( + "--dataloader-num-workers", + type=int, + default=4, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + # training hyper params + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=10) + training_group.add_argument( + "--max-num-steps", + type=int, + default=None, + help="The maximum number of steps to train. If not provided, will be calculated as num_epochs * steps_per_epoch", + ) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=1e-4) + training_group.add_argument("--max-length", type=int, default=2048) + training_group.add_argument("--warmup-ratio", type=float, default=0.015) + training_group.add_argument( + "--total-steps", + type=int, + default=None, + help="Total training steps. If not provided, will be calculated as num_epochs * steps_per_epoch", + ) + training_group.add_argument("--max-grad-norm", type=float, default=0.5) + training_group.add_argument( + "--ttt-length", + type=int, + default=7, + help="The length for Test-Time Training (TTT).", + ) + training_group.add_argument("--resume", action="store_true") + training_group.add_argument( + "--ckpt-dir", + type=str, + default=None, + help="directory includes the checkpoint to start training with", + ) + training_group.add_argument("--eval-interval", type=int, default=5000) + training_group.add_argument("--save-interval", type=int, default=5000) + training_group.add_argument( + "--log-interval", + type=int, + default=50, + help="Log training metrics every N steps", + ) + training_group.add_argument("--seed", type=int, default=0) + training_group.add_argument("--draft-accumulation-steps", type=int, default=1) + training_group.add_argument( + "--optimizer-type", + type=str, + default="adamw", + choices=["adamw", "adamw_8bit", "apollo"], + help="Optimizer type (default: adamw)", + ) + training_group.add_argument( + "--optimizer-config", + type=str, + default=None, + help="Path to optimizer config JSON file (required for apollo)", + ) + training_group.add_argument( + "--no-fp32-params", + action="store_true", + help="Disable FP32 master copy of parameters to save memory", + ) + + # data processing type + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--tp-size", + type=int, + default=1, + help="The size of the tensor parallel for the target model", + ) + # distributed training + optimization_group.add_argument("--sp-ulysses-size", type=int, default=1) + optimization_group.add_argument("--sp-ring-size", type=int, default=1) + optimization_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + help="The attention backend for the draft model", + ) + + # other args + other_group = parser.add_argument_group("others") + other_group.add_argument("--cache-key", type=str, default=None) + other_group.add_argument("--cache-dir", type=str, default="./cache") + other_group.add_argument("--output-dir", type=str, required=True) + other_group.add_argument("--verbose", action="store_true") + other_group.add_argument( + "--dist-timeout", + type=int, + default=20, + help="Timeout for collective communication in minutes", + ) + other_group.add_argument( + "--model-download-dir", + type=str, + default=None, + help="The directory to download the target model to", + ) + + # vlm related args + vlm_group = parser.add_argument_group("vlm") + vlm_group.add_argument( + "--min-pixels", type=int, default=50176 + ) # 64*28*28 for qwen2.5-vl + vlm_group.add_argument( + "--max-pixels", type=int, default=802816 + ) # 1024*28*28 for qwen2.5-vl + + # profiling related args + profiling_group = parser.add_argument_group("profiling") + profiling_group.add_argument("--profile", action="store_true") + profiling_group.add_argument("--profile-start-step", type=int, default=30) + profiling_group.add_argument("--profile-num-steps", type=int, default=4) + profiling_group.add_argument("--profile-record-shapes", action="store_true") + + # sglang target model backend related args + sglang_group = parser.add_argument_group("sglang target model backend") + SGLangBackendArgs.add_args(sglang_group) + + # tracker related args + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + args = parser.parse_args() + return parser, args + + +def build_tracker(args: Namespace, parser: ArgumentParser) -> Tracker: + """ + Build the experiment tracker according to the report_to argument. + + Args: + args: The arguments for the training script. + parser: The parser for the training script. + + Returns: + The experiment tracker. + """ + tracker_class = get_tracker_class(args.report_to) + if tracker_class: + tracker_class.validate_args(parser, args) + else: + parser.error(f"Unknown tracker: {args.report_to}") + tracker = create_tracker(args, args.output_dir) + return tracker + + +def build_target_model( + args: Namespace, draft_model_config: AutoDraftModelConfig, is_online: bool = True +) -> Tuple[Union[Eagle3TargetModel, TargetHead], Optional[AutoProcessor]]: + """ + Build the target model according to the arguments. + + Args: + args: The arguments for the training script. + draft_model_config: The draft model config. + + Returns: + The target model. + """ + if is_online: + if ( + args.is_vlm + and draft_model_config.target_model_type == "qwen2_5_vl" + and args.target_model_backend == "custom" + ): + from transformers import Qwen2_5_VLForConditionalGeneration + + target_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path=args.target_model_path, + torch_dtype=torch.bfloat16, + ) + .eval() + .cuda() + ) + else: + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + else: + target_model_kwargs = {} + target_model = get_eagle3_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device="cuda", + cache_dir=args.model_download_dir, + **target_model_kwargs, + trust_remote_code=args.trust_remote_code, + # attn_implementation="flash_attention_2" + ) + + # set the aux hidden states layers + if ( + hasattr(draft_model_config, "eagle_config") + and draft_model_config.eagle_config is not None + and "eagle_aux_hidden_state_layer_ids" in draft_model_config.eagle_config + ): + target_model.set_aux_hidden_states_layers( + draft_model_config.eagle_config["eagle_aux_hidden_state_layer_ids"] + ) + else: + target_model.set_aux_hidden_states_layers() + + if args.is_vlm: + processor = AutoProcessor.from_pretrained( + args.target_model_path, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + ) + else: + processor = None + + return target_model, processor + else: + target_head = TargetHead.from_pretrained( + model_path=args.target_model_path, + lm_head_key=args.lm_head_key, + cache_dir=args.model_download_dir, + trust_remote_code=args.trust_remote_code, + ) + return target_head, None + + +def sanity_check(args: Namespace) -> None: + """ + Perform sanity checks on the arguments. + + Args: + args: The arguments for the training script. + + Returns: + None + """ + args.dp_size = dist.get_world_size() // args.tp_size + args.target_batch_size = args.tp_size * args.batch_size + if args.attention_backend == "usp": + sp_sanity_check(args) + + +def sp_sanity_check(args: Namespace) -> None: + args.draft_accumulation_steps = ( + args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size + ) + assert ( + args.batch_size == 1 + ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" + + assert args.sp_ring_size * args.sp_ulysses_size > 1, ( + f"USP requires sp_ring_size * sp_ulysses_size > 1. " + f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}." + ) + + assert args.train_hidden_states_path is not None, f"USP only support offline mode" + + if args.eval_data_path is not None and args.eval_hidden_states_path is not None: + raise ValueError( + "Cannot set both eval_data_path and eval_hidden_states_path. " + "For online mode, set only eval_data_path. " + "For offline mode, set only eval_hidden_states_path." + ) + + +def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]: + # Handle draft model config + if args.draft_model_config is None: + # Auto-generate and save config file + auto_config_path = create_draft_config_from_target( + target_model_path=args.target_model_path, cache_dir=args.model_download_dir + ) + draft_model_config = AutoDraftModelConfig.from_file(auto_config_path) + else: + # Use provided config file + draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) + + # Handle base ckpt, config file + draft_model_last_checkpoint = None + if args.ckpt_dir is not None: + if os.path.isdir(args.ckpt_dir): + draft_model_config = AutoDraftModelConfig.from_file( + os.path.join(args.ckpt_dir, "config.json") + ) + draft_model_last_checkpoint = args.ckpt_dir + print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}") + else: + raise ValueError( + f"Provided base model dir {args.ckpt_dir} is not a valid directory." + ) + + # detecting last ckpt for draft model + if args.resume and os.path.isdir(args.output_dir): + print_on_rank0(args.output_dir) + draft_model_last_checkpoint = get_last_checkpoint(args.output_dir) + print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}") + + if draft_model_last_checkpoint: + draft_model = AutoEagle3DraftModel.from_pretrained( + draft_model_last_checkpoint, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() + else: + draft_model = AutoEagle3DraftModel.from_config( + draft_model_config, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() + + draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) + draft_model.freeze_embedding() + return draft_model_config, draft_model + + +def build_dataloaders( + args: Namespace, + draft_model_config: AutoDraftModelConfig, + processor: Optional[AutoProcessor] = None, +) -> Tuple[DataLoader, str, Optional[DataLoader]]: + # build dataloaders + tokenizer = AutoTokenizer.from_pretrained( + args.target_model_path, trust_remote_code=args.trust_remote_code + ) + + # convert to dataloader + cache_params_string = ( + f"{args.train_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" # Tokenizer may also different + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + train_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.train_data_path}, + ) + is_online = ( + args.train_data_path is not None and args.train_hidden_states_path is None + ) + with rank_0_priority(): + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + is_vlm=args.is_vlm, + is_preformatted=args.is_preformatted, + processor=processor, + num_proc=args.build_dataset_num_proc, + train_only_last_turn=args.train_only_last_turn, + ) + vocab_mapping_path = generate_vocab_mapping_file( + dataset=train_eagle3_dataset, + target_vocab_size=draft_model_config.vocab_size, + draft_vocab_size=draft_model_config.draft_vocab_size, + cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), + cache_key=cache_key, + ) + + if not is_online: + train_eagle3_dataset = build_offline_eagle3_dataset( + args.train_hidden_states_path, + args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), + ) + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.target_batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=( + get_draft_dp_group() + if args.attention_backend == "usp" and not is_online + else get_dp_group() + ), + is_vlm=args.is_vlm, + ) + if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + if args.eval_data_path is not None: + eval_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.eval_data_path}, + ) + eval_eagle3_dataset = build_eagle3_dataset( + eval_dataset, + tokenizer, + args.chat_template, + args.max_length, + is_vlm=args.is_vlm, + processor=processor, + num_proc=args.build_dataset_num_proc, + is_preformatted=args.is_preformatted, + train_only_last_turn=args.train_only_last_turn, + ) + elif args.eval_hidden_states_path is not None: + eval_eagle3_dataset = build_offline_eagle3_dataset( + args.eval_hidden_states_path, + args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.target_batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=( + get_draft_dp_group() + if args.attention_backend == "usp" and not is_online + else get_dp_group() + ), + is_vlm=args.is_vlm, + ) + print_with_rank("Initialized eval dataloader") + else: + eval_dataloader = None + return ( + train_dataloader, + vocab_mapping_path, + eval_dataloader, + ) + + +def save_checkpoints( + args: Namespace, + epoch: int, + step: int, + eagle3_model: nn.Module, + optimizer: Optimizer, +): + epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(epoch_output_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT): + model_state_dict = eagle3_model.state_dict() + state_to_save = { + "epoch": epoch, + "global_step": step, + "args": args, + } + state_to_save.update(optimizer.state_dict()) + draft_model_state_dict = { + k.replace("draft_model.", ""): v + for k, v in model_state_dict.items() + if "draft_model." in k and "embed" not in k.lower() + } + + if dist.get_rank() == 0: + torch.save( + state_to_save, + os.path.join(epoch_output_dir, "training_state.pt"), + ) + print_on_rank0( + f"Saved full training state to {epoch_output_dir}/training_state.pt" + ) + eagle3_model.draft_model.save_pretrained( + epoch_output_dir, + state_dict=draft_model_state_dict, + ) + print_on_rank0(f"Saved model configuration to {epoch_output_dir}") + dist.barrier() + + +def run_forward( + args: Namespace, + eagle3_model: nn.Module, + data: dict, + target_model: Optional[Eagle3TargetModel] = None, + is_online: bool = True, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + if args.is_vlm and args.target_model_backend == "custom": + plosses, _, acces = eagle3_model( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + pixel_values=data["pixel_values"].cuda(), + image_grid_thw=data["image_grid_thw"].cuda(), + ) + else: + image_grid_thw = None + if is_online: + # we generate the eagle3 using the target model in an online fashion + # Handle VLM data: pixel_values and image_grid_thw are lists + # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None + if args.is_vlm: + image_grid_thw = ( + [thw.cuda().squeeze() for thw in data["image_grid_thw"]] + if args.is_vlm + else None + ) + pixel_values = data["pixel_values"].cuda() + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + is_vlm=args.is_vlm, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + else: + eagle3_data = target_model.generate_eagle3_data( + input_ids=data["input_ids"].cuda(), + attention_mask=data["attention_mask"].cuda(), + loss_mask=data["loss_mask"].cuda(), + ) + + input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) + attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) + loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask) + target = get_dp_data_shard_from_tp(eagle3_data.target) + hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states) + else: + # we generate the logits using the hidden states loaded from disk + attention_mask = data["attention_mask"].cuda() + hidden_states = data["hidden_state"].cuda() + input_ids, target, loss_mask = target_model.preprocess( + data["input_ids"], data["target"], data["loss_mask"] + ) + input_ids = input_ids.cuda() + target = target_model( + target.cuda() + ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU. + loss_mask = loss_mask.cuda() + plosses, _, acces = eagle3_model( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + target=target, + hidden_states=hidden_states, + position_ids=( + data["position_ids"].cuda() if "position_ids" in data else None + ), + image_grid_thw=image_grid_thw, + is_vlm=args.is_vlm, + ) + return plosses, acces + + +def run_backward_and_update( + args: Namespace, plosses: List[torch.Tensor], optimizer: Optimizer, global_step: int +) -> None: + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = ( + sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + / args.draft_accumulation_steps + ) + ploss.backward() + + if global_step % args.draft_accumulation_steps == 0: + optimizer.step() + + +def record_metrcs( + args: Namespace, + accuracies: List[torch.Tensor], + plosses: List[torch.Tensor], + global_step: int, + tracker: Tracker, + optimizer: Optional[Optimizer] = None, + mode: str = "train", +) -> None: + logdict = {} + + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + + accuracies = torch.stack(accuracies) + plosses = torch.stack(plosses) + + assert accuracies.shape[0] == args.ttt_length + dist.all_reduce(accuracies, op=dist.ReduceOp.AVG) + accuracies = accuracies.cpu().tolist() + for i in range(len(accuracies)): + logdict[f"{mode}/acc_{i}"] = accuracies[i] + print_on_rank0( + f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, Acc: {accuracies[i]:.2f}" + ) + + dist.all_reduce(plosses, op=dist.ReduceOp.AVG) + plosses = plosses.cpu().tolist() + for i in range(len(plosses)): + logdict[f"{mode}/ploss_{i}"] = plosses[i] + print_on_rank0( + f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, pLoss: {plosses[i]}" + ) + tracker.log(logdict, step=global_step) + + +def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor: + """ + Get the data shard from the tensor. + """ + tp_size = dist.get_world_size(get_tp_group()) + tp_rank = dist.get_rank(get_tp_group()) + return tensor.chunk(tp_size, dim=0)[tp_rank] + + +def main(): + # ================================================ + # 1. Initialize + # ================================================ + parser, args = parse_args() + set_seed(args.seed) + init_distributed( + timeout=args.dist_timeout, + tp_size=args.tp_size, + sp_ring_size=args.sp_ring_size, + sp_ulysses_size=args.sp_ulysses_size, + ) + is_online = ( + args.train_data_path is not None and args.train_hidden_states_path is None + ) + + sanity_check(args) + print_args_with_dots(args) + print_with_rank("Initialized distributed environment") + + # ================================================ + # 2. Build models + # ================================================ + draft_model_config, draft_model = build_draft_model(args) + target_model, processor = build_target_model(args, draft_model_config, is_online) + + # ================================================ + # 3. Build dataloader + # ================================================ + train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders( + args, draft_model_config, processor + ) + + # we load the vocab mapping then + draft_model.load_vocab_mapping(vocab_mapping_path) + print_with_rank("Loaded vocab mapping") + + # Calculate total steps if not provided + if args.total_steps is None: + steps_per_epoch = math.ceil( + len(train_dataloader) / args.draft_accumulation_steps + ) + args.total_steps = args.num_epochs * steps_per_epoch + print_with_rank( + f"Auto-calculated total_steps: {args.total_steps} (num_epochs={args.num_epochs} * steps_per_epoch={steps_per_epoch})" + ) + else: + print_with_rank(f"Using provided total_steps: {args.total_steps}") + + # ================================================ + # 4. Build Eagle3 model + # ================================================ + if ( + args.is_vlm + and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl" + and args.tp_size == 1 + and args.target_model_backend != "sglang" + ): + eagle3_model = QwenVLOnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + processor=processor, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + if is_online: + eagle3_model = OnlineEagle3Model( + target_model=target_model, + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + else: + # offline: the target_model is TargetHead not a model + eagle3_model = OnlineEagle3Model( + draft_model=draft_model, + length=args.ttt_length, + attention_backend=args.attention_backend, + ) + eagle3_model = FSDP( + eagle3_model, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + process_group=dist.group.WORLD, # the draft model should run dp for all processes + ) + # target_model.model = FSDP( + # target_model.model, + # use_orig_params=True, + # mixed_precision=MixedPrecision( + # param_dtype=torch.bfloat16, + # buffer_dtype=torch.bfloat16, + # ), + # sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + # process_group=dist.group.WORLD, # the draft model should run dp for all processes + # ) + print_with_rank("Initialized Eagle3 FSDP model") + + # ================================================ + # 5. Build optimizer and scheduler + # ================================================ + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=args.total_steps, + use_fp32_params=not args.no_fp32_params, + optimizer_type=args.optimizer_type, + optimizer_config=args.optimizer_config, + ) + print_with_rank("Initialized optimizer and scheduler") + + # ================================================ + # 6. Build tracker + # ================================================ + tracker = build_tracker(args, parser) + global_step = 0 + start_epoch = 0 + dist.barrier() + + last_time = time.time() + + # ================================================ + # 7. Start training + # ================================================ + print_on_rank0(f"Starting training from epoch {start_epoch}") + + for epoch in range(start_epoch, args.num_epochs): + # Run training + train_dataloader.sampler.set_epoch(epoch + 1) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm( + train_dataloader, desc=f"Training Epoch {epoch}", leave=True + ) + else: + progress_bar = train_dataloader + + for data in progress_bar: + global_step += 1 + + # ================================================ + # 7.0 Profiling + # ================================================ + if args.profile: + # we add the step by 1 to align with global step + if global_step == args.profile_start_step + 1: + print("Start profile") + torch_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=args.profile_record_shapes, + ) + torch_profiler.start() + if global_step == args.profile_start_step + args.profile_num_steps + 1: + output_path = os.path.join( + args.output_dir, + f"profile_rank{torch.distributed.get_rank()}_{time.time()}.trace.json.gz", + ) + print(f"End profile {output_path=}") + torch_profiler.stop() + torch_profiler.export_chrome_trace(output_path) + + # ================================================ + # 7.1 Training Step + # ================================================ + plosses, acces = run_forward( + args, + eagle3_model, + data, + target_model, + is_online, + ) + run_backward_and_update(args, plosses, optimizer, global_step) + + # log training metrics + if global_step % (args.log_interval * args.draft_accumulation_steps) == 0: + record_metrcs( + args, + acces, + plosses, + global_step // args.draft_accumulation_steps, + tracker, + optimizer, + mode="train", + ) + + if dist.get_rank() == 0: + time_per_step = time.time() - last_time + last_time = time.time() + avg_loss = sum(pl for pl in plosses) / len(plosses) + avg_acc = sum(acces) / len(acces) + progress_bar.set_postfix( + { + "loss": f"{avg_loss:.2f}", + "acc": f"{avg_acc:.2f}", + "time": f"{time_per_step:.2f}s", + } + ) + + # ================================================ + # 7.2 Evaluation Step + # ================================================ + should_evaluate = ( + args.eval_data_path is not None + or args.eval_hidden_states_path is not None + ) + if ( + should_evaluate + and global_step % (args.eval_interval * args.draft_accumulation_steps) + == 0 + ): + # Run evaluation + draft_model.eval() + eval_acces = [[] for _ in range(eagle3_model.length)] + eval_plosses = [[] for _ in range(eagle3_model.length)] + + for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"): + with torch.no_grad(): + plosses, acces = run_forward( + args, eagle3_model, data, target_model, is_online + ) + eval_acces = [ + eval_acces[i] + [acces[i]] for i in range(len(acces)) + ] + eval_plosses = [ + eval_plosses[i] + [plosses[i]] for i in range(len(plosses)) + ] + + # compute average over all minibatches + eval_acces = [torch.stack(acc).mean() for acc in eval_acces] + eval_plosses = [torch.stack(pl).mean() for pl in eval_plosses] + + record_metrcs( + args, + eval_acces, + eval_plosses, + global_step // args.draft_accumulation_steps, + tracker, + mode="eval", + ) + # ================================================ + # 7.3 Save Checkpoints + # ================================================ + if global_step % args.save_interval == 0: + # Save the model + save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + + if args.max_num_steps is not None and global_step >= args.max_num_steps: + break + + if args.max_num_steps is not None and global_step >= args.max_num_steps: + break + # Save final checkpoint if training ended without saving + if global_step % args.save_interval != 0: + print_on_rank0( + f"Training completed at step {global_step}, saving final checkpoint..." + ) + save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) + + # Close the tracker + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/specforge/__init__.py b/progress/SpecForge/specforge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b07280a0d9e106da207bd1b75de4a22a2de215b1 --- /dev/null +++ b/progress/SpecForge/specforge/__init__.py @@ -0,0 +1,4 @@ +from .core import * # noqa +from .modeling import * # noqa + +__all__ = ["modeling", "core"] diff --git a/progress/SpecForge/specforge/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d483b6e046f9867c0fa7d9c4d5dafdf5d883fa Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/args.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbef2e5fbb8418a387816bff6f94edf3cc0e86cb Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/args.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/distributed.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3cdd69e0d4721249532b3f10e034475be9a03a7 Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/distributed.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/lr_scheduler.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/lr_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f927416b1f5fa85878246f3e3c687ff47324c6a Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/lr_scheduler.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/optimizer.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b6dd3d2fa9fae0403a352c537376076b97d7a78 Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/optimizer.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/tracker.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/tracker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c75b9fee3c50aac0aee6473207025fb9f68b586 Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/tracker.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/__pycache__/utils.cpython-311.pyc b/progress/SpecForge/specforge/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e70e4df9c6712fc8f18fcbed94fec332442655f4 Binary files /dev/null and b/progress/SpecForge/specforge/__pycache__/utils.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/args.py b/progress/SpecForge/specforge/args.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6de14c9a473f0d992793dd76bf133eb6c371b3 --- /dev/null +++ b/progress/SpecForge/specforge/args.py @@ -0,0 +1,206 @@ +import argparse +from dataclasses import dataclass +from typing import Any, Dict, List + +from sglang.srt.server_args import ATTENTION_BACKEND_CHOICES + + +@dataclass +class TrackerArgs: + report_to: str = "none" + wandb_project: str = None + wandb_name: str = None + wandb_key: str = None + swanlab_project: str = None + swanlab_name: str = None + swanlab_key: str = None + mlflow_experiment_id: str = None + mlflow_run_name: str = None + mlflow_run_id: str = None + mlflow_tracking_uri: str = None + mlflow_registry_uri: str = None + + @staticmethod + def add_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--report-to", + type=str, + default="none", + choices=["wandb", "tensorboard", "swanlab", "mlflow", "none"], + help="The integration to report results and logs to.", + ) + # wandb-specific args + parser.add_argument("--wandb-project", type=str, default=None) + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.") + # swanlab-specific args + parser.add_argument( + "--swanlab-project", + type=str, + default=None, + help="The project name for swanlab.", + ) + parser.add_argument( + "--swanlab-name", + type=str, + default=None, + help="The experiment name for swanlab.", + ) + parser.add_argument( + "--swanlab-key", + type=str, + default=None, + help="The API key for swanlab non-interactive login.", + ) + # mlflow-specific args + parser.add_argument( + "--mlflow-tracking-uri", + type=str, + default=None, + help="The MLflow tracking URI. If not set, uses MLFLOW_TRACKING_URI environment variable or defaults to local './mlruns'.", + ) + parser.add_argument( + "--mlflow-experiment-name", + type=str, + default=None, + help="The MLflow experiment name. If not set, uses MLFLOW_EXPERIMENT_NAME environment variable.", + ) + parser.add_argument( + "--mlflow-run-name", + type=str, + default=None, + help="The MLflow run name. If not set, MLflow will auto-generate one.", + ) + + +@dataclass +class SGLangBackendArgs: + sglang_attention_backend: str = "fa3" + sglang_mem_fraction_static: float = 0.4 + sglang_context_length: int = None + sglang_enable_nccl_nvls: bool = False + sglang_enable_symm_mem: bool = False + sglang_enable_torch_compile: bool = True + sglang_enable_dp_attention: bool = False + sglang_enable_dp_lm_head: bool = False + sglang_enable_piecewise_cuda_graph: bool = False + sglang_piecewise_cuda_graph_max_tokens: int = 4096 + sglang_piecewise_cuda_graph_tokens: List[int] = None + sglang_ep_size: int = 1 + sglang_max_running_requests: int = None # assign based on batch size + sglang_max_total_tokens: int = None # assign based on batch size and seq length + + @staticmethod + def add_args(parser: argparse.ArgumentParser) -> None: + # sglang arguments + parser.add_argument( + "--sglang-attention-backend", + type=str, + default="flashinfer", + choices=ATTENTION_BACKEND_CHOICES, + help="The attention backend of SGLang backend", + ) + parser.add_argument( + "--sglang-mem-fraction-static", + type=float, + default=0.4, + help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.", + ) + parser.add_argument( + "--sglang-context-length", + type=int, + default=None, + help="The context length of the SGLang backend", + ) + parser.add_argument( + "--sglang-enable-nccl-nvls", + action="store_true", + help="Enable NCCL NVLS for prefill heavy requests when available for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-symm-mem", + action="store_true", + help="Enable NCCL symmetric memory for fast collectives for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-torch-compile", + action="store_true", + help="Optimize the model with torch.compile for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-dp-attention", + action="store_true", + help="Enable DP attention for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-dp-lm-head", + action="store_true", + help="Enable piecewise CUDA graph for SGLang backend", + ) + parser.add_argument( + "--sglang-enable-piecewise-cuda-graph", + action="store_true", + help="Enable piecewise CUDA graph for SGLang backend's prefill", + ) + parser.add_argument( + "--sglang-piecewise-cuda-graph-max-tokens", + type=int, + default=4096, + help="Set the max tokens for piecewise CUDA graph for SGLang backend", + ) + parser.add_argument( + "--sglang-piecewise-cuda-graph-tokens", + type=int, + nargs="+", + default=None, + help="Set the list of tokens when using piecewise cuda graph for SGLang backend", + ) + parser.add_argument( + "--sglang-ep-size", + type=int, + default=1, + help="The ep size of the SGLang backend", + ) + + @staticmethod + def from_args(args: argparse.Namespace) -> "SGLangBackendArgs": + return SGLangBackendArgs( + sglang_attention_backend=args.sglang_attention_backend, + sglang_mem_fraction_static=args.sglang_mem_fraction_static, + sglang_context_length=args.sglang_context_length, + sglang_enable_nccl_nvls=args.sglang_enable_nccl_nvls, + sglang_enable_symm_mem=args.sglang_enable_symm_mem, + sglang_enable_torch_compile=args.sglang_enable_torch_compile, + sglang_enable_dp_attention=args.sglang_enable_dp_attention, + sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head, + sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph, + sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens, + sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens, + sglang_ep_size=args.sglang_ep_size, + sglang_max_running_requests=( + args.target_batch_size if hasattr(args, "target_batch_size") else None + ), + sglang_max_total_tokens=( + args.target_batch_size * args.max_length + if hasattr(args, "target_batch_size") and hasattr(args, "max_length") + else None + ), + ) + + def to_kwargs(self) -> Dict[str, Any]: + return dict( + attention_backend=self.sglang_attention_backend, + mem_fraction_static=self.sglang_mem_fraction_static, + context_length=self.sglang_context_length, + enable_nccl_nvls=self.sglang_enable_nccl_nvls, + enable_symm_mem=self.sglang_enable_symm_mem, + enable_torch_compile=self.sglang_enable_torch_compile, + enable_dp_attention=self.sglang_enable_dp_attention, + enable_dp_lm_head=self.sglang_enable_dp_lm_head, + enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph, + piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens, + piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens, + ep_size=self.sglang_ep_size, + max_running_requests=self.sglang_max_running_requests, + max_total_tokens=self.sglang_max_total_tokens, + ) diff --git a/progress/SpecForge/specforge/benchmarks/benchmark_flex_attention.py b/progress/SpecForge/specforge/benchmarks/benchmark_flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..20f989565727ffe42ab112c8818ac32371b9313d --- /dev/null +++ b/progress/SpecForge/specforge/benchmarks/benchmark_flex_attention.py @@ -0,0 +1,336 @@ +import argparse +import time + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch._dynamo as dynamo +from transformers import LlamaConfig +from transformers.cache_utils import DynamicCache + +from specforge.modeling.draft.llama3_eagle import ( + LlamaAttention, + LlamaFlexAttention, + prepare_decoder_attention_mask, +) + +dynamo.config.recompile_limit = 64 + +config_dict = { + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "max_position_embeddings": 16384, + "rms_norm_eps": 1e-05, + "vocab_size": 32000, + "hidden_act": "silu", + "num_hidden_layers": 1, +} + +config = LlamaConfig(**config_dict) + +TTT_LENGTH = 7 +BATCH_SIZE = 4 +HIDDEN_SIZE = config.hidden_size * 2 + + +def run_attention( + seq_len: int, + hidden_states_list: list[torch.Tensor], + attention_backend: str = "sdpa", + enable_profile: bool = False, +): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + batch_size = hidden_states_list[0].shape[0] + # Initialize cache and attention function based on backend + if attention_backend == "sdpa": + cache_hidden = [[], []] + past_key_values = None + attn_func = LlamaAttention(config).to(device).to(torch.bfloat16) + elif attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + attn_func = LlamaFlexAttention(config).to(device).to(torch.bfloat16) + else: + raise ValueError(f"Unknown attention backend: {attention_backend}") + + # Simulate inputs - move to device + position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(device) + input_embeds = torch.randn(batch_size, seq_len, config.hidden_size).to(device) + attention_mask = torch.ones(batch_size, seq_len).to(device) + decoder_attention_mask = prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_len), + inputs_embeds=input_embeds, + past_key_values_length=0, + ) + + loss_list = [] + + if attention_backend == "flex_attention" and enable_profile: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"./profiler_logs/{attention_backend}" + ), + record_shapes=False, + profile_memory=False, + with_stack=True, + with_modules=False, + ) + profiler.start() + for idx in range(TTT_LENGTH): + is_last = idx == TTT_LENGTH - 1 + hidden_states = hidden_states_list[idx] + # Call attention function with appropriate parameters + if attention_backend == "sdpa": + output = attn_func( + hidden_states=hidden_states, + attention_mask=decoder_attention_mask, + position_ids=position_ids, + cache_hidden=cache_hidden, + output_attentions=False, + use_cache=True, + ) + else: # flex_attention + output = attn_func( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=False, + use_cache=True, + ) + + # Compute a simple loss for benchmarking + loss = output[0].sum() + loss_list.append(loss) + + # Compute mean loss and backward pass + if loss_list: + mean_loss = sum(loss_list) / len(loss_list) + mean_loss.backward() + + if attention_backend == "flex_attention" and enable_profile: + profiler.stop() + + +def benchmark_function( + attention_backend: str, + seq_lengths: list, + enable_profile: bool = False, + enable_warmup: bool = True, +): + """Benchmark a function for speed and GPU memory usage per sequence length.""" + print(f"\n=== Benchmarking {attention_backend} ===") + + results_per_seq_len = [] + + for seq_len in seq_lengths: + print(f"\nTesting sequence length: {seq_len}") + + # Clear GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Warm up runs for this sequence length + if enable_warmup: + print("Warming up...") + for _ in range(2): + hidden_states = [ + torch.randn( + BATCH_SIZE, + seq_len, + HIDDEN_SIZE, + requires_grad=True, + device="cuda", + dtype=torch.bfloat16, + ) + for _ in range(TTT_LENGTH) + ] + run_attention(seq_len, hidden_states, attention_backend) + # Clear cache again after warmup + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + # Record initial memory + initial_memory = 0 + if torch.cuda.is_available(): + initial_memory = torch.cuda.memory_allocated() + hidden_states = [ + torch.randn( + BATCH_SIZE, + seq_len, + HIDDEN_SIZE, + requires_grad=True, + device="cuda", + dtype=torch.bfloat16, + ) + for _ in range(TTT_LENGTH) + ] + start_time = time.time() + run_attention( + seq_len, + hidden_states, + attention_backend, + enable_profile and seq_len == seq_lengths[0], + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.time() + + # Record memory usage + peak_memory = 0 + current_memory = 0 + if torch.cuda.is_available(): + peak_memory = torch.cuda.max_memory_allocated() + current_memory = torch.cuda.memory_allocated() + results_per_seq_len.append( + { + "seq_len": seq_len, + "time": end_time - start_time, + "peak_memory": peak_memory, + "memory_increase": current_memory - initial_memory, + } + ) + + print(f" Time: {end_time - start_time:.3f}s") + print(f" Peak memory: {peak_memory / 1024**3:.3f} GB") + print( + f" Memory increase: {(current_memory - initial_memory) / 1024**3:.3f} GB" + ) + + return results_per_seq_len + + +def plot_results(eagle_results, flex_results, seq_lengths): + """Plot speed and memory comparison between Eagle and Flex attention.""" + + # Extract data for plotting + eagle_times = [r["time"] for r in eagle_results] + flex_times = [r["time"] for r in flex_results] + eagle_memory = [r["peak_memory"] / 1024**3 for r in eagle_results] # Convert to GB + flex_memory = [r["peak_memory"] / 1024**3 for r in flex_results] # Convert to GB + + # Create subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + # Speed comparison plot + ax1.plot( + seq_lengths, eagle_times, "b-o", label="Eagle (SDPA)", linewidth=2, markersize=8 + ) + ax1.plot( + seq_lengths, + flex_times, + "r-s", + label="Flex Attention", + linewidth=2, + markersize=8, + ) + ax1.set_xlabel("Sequence Length") + ax1.set_ylabel("Time (seconds)") + ax1.set_title("Speed Comparison: Eagle vs Flex Attention") + ax1.legend() + ax1.grid(True, alpha=0.3) + ax1.set_xscale("linear") + ax1.set_yscale("log") + + # Memory comparison plot + ax2.plot( + seq_lengths, + eagle_memory, + "b-o", + label="Eagle (SDPA)", + linewidth=2, + markersize=8, + ) + ax2.plot( + seq_lengths, + flex_memory, + "r-s", + label="Flex Attention", + linewidth=2, + markersize=8, + ) + ax2.set_xlabel("Sequence Length") + ax2.set_ylabel("Peak Memory (GB)") + ax2.set_title("Memory Usage Comparison: Eagle vs Flex Attention") + ax2.legend() + ax2.grid(True, alpha=0.3) + + # Set y-axis ticks every 10GB + max_memory = max(max(eagle_memory), max(flex_memory)) + ax2.set_yticks(np.arange(0, max_memory + 10, 10)) + + plt.tight_layout() + plt.savefig("attention_benchmark_comparison.png", dpi=300, bbox_inches="tight") + plt.show() + + # Print summary statistics + print(f"\n=== Performance Summary ===") + print(f"Sequence lengths tested: {seq_lengths}") + print(f"\nSpeed ratios (Eagle/Flex):") + for i, seq_len in enumerate(seq_lengths): + ratio = eagle_times[i] / flex_times[i] if flex_times[i] > 0 else float("inf") + print(f" {seq_len:4d}: {ratio:.2f}x") + + print(f"\nMemory ratios (Eagle/Flex):") + for i, seq_len in enumerate(seq_lengths): + ratio = eagle_memory[i] / flex_memory[i] if flex_memory[i] > 0 else float("inf") + print(f" {seq_len:4d}: {ratio:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark attention mechanisms") + parser.add_argument( + "--enable-profile", action="store_true", help="Enable profiling" + ) + args = parser.parse_args() + + print("PyTorch version:", torch.__version__) + if torch.cuda.is_available(): + print("CUDA available:", torch.cuda.is_available()) + print("GPU:", torch.cuda.get_device_name()) + print( + "GPU memory:", + torch.cuda.get_device_properties(0).total_memory / 1024**3, + "GB", + ) + else: + print("CUDA not available - running on CPU") + + # Define sequence lengths to test + seq_lengths = [128 * i for i in range(1, 28, 4)] + # Add extra long context + seq_lengths.extend([16384, 32768]) + + print(f"Testing sequence lengths: {seq_lengths}") + + # Run benchmarks + print("\n" + "=" * 50) + # Truncate seqlen after 2560 since naive eagle goes OOM + eagle_seq_lengths = [seq_len for seq_len in seq_lengths if seq_len <= 2560] + eagle_results = benchmark_function("sdpa", eagle_seq_lengths) + print("\n" + "=" * 50) + flex_results = benchmark_function( + "flex_attention", seq_lengths, enable_profile=args.enable_profile + ) + # Pad the memory usage on eagle to max memory 80GB when data not available + max_time = max(result["time"] for result in flex_results) + for result in flex_results: + if result["seq_len"] not in eagle_seq_lengths: + eagle_results.append( + { + "seq_len": result["seq_len"], + "time": max_time, + "peak_memory": 80 * 1024**3, + "memory_increase": 0, # Not used in plotting + } + ) + + # Plot results + plot_results(eagle_results, flex_results, seq_lengths) diff --git a/progress/SpecForge/specforge/benchmarks/benchmark_loss.py b/progress/SpecForge/specforge/benchmarks/benchmark_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..940787a860d98ee406bee6a29df9127bae675d92 --- /dev/null +++ b/progress/SpecForge/specforge/benchmarks/benchmark_loss.py @@ -0,0 +1,179 @@ +import argparse +import time + +import torch + +from specforge.core.loss import LogSoftmaxLoss, _compute_loss + +TTT_LENGTH = 7 + + +def benchmark_loss_method( + loss_method: str, + test_configs: list, +): + """Benchmark a loss computation method for speed and GPU memory usage.""" + print(f"\n=== Benchmarking {loss_method} Loss ===") + + results = [] + + for config in test_configs: + B, T, V = config + print(f"\nTesting config: B={B}, T={T}, V={V}") + + # Clear GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Create tensors outside timing measurement + target = torch.softmax( + torch.randn(B, T, V, device="cuda", dtype=torch.float32), dim=-1 + ) + position_mask = torch.ones((B, T, 1), dtype=torch.bool, device="cuda") + + # Pre-allocate logits tensors for each TTT step + logits_list = [] + for i in range(TTT_LENGTH): + logits = torch.randn( + B, T, V, device="cuda", requires_grad=True, dtype=torch.float32 + ) + logits_list.append(logits) + + torch.cuda.synchronize() # Ensure all operations are complete + start_time = time.time() + + plosses = [] + for i in range(TTT_LENGTH): + logits = logits_list[i] + if loss_method == "triton": + loss = LogSoftmaxLoss.apply(logits, target, position_mask) + else: + loss = _compute_loss(logits, target, position_mask) + plosses.append(loss) + + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = ( + sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + / TTT_LENGTH + ) + ploss.backward() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.time() + total_time = end_time - start_time + # Record memory usage + peak_memory = 0 + if torch.cuda.is_available(): + peak_memory = torch.cuda.max_memory_allocated() + + results.append( + { + "B": B, + "T": T, + "V": V, + "time_total": total_time, + "peak_memory": peak_memory, + } + ) + + print(f" Total time (forward + backward): {total_time*1000:.3f}ms") + print(f" Peak memory: {peak_memory / 1024**3:.3f} GB") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark loss computation methods") + parser.add_argument( + "--num-runs", type=int, default=5, help="Number of runs for averaging" + ) + args = parser.parse_args() + + print("PyTorch version:", torch.__version__) + if torch.cuda.is_available(): + print("CUDA available:", torch.cuda.is_available()) + print("GPU:", torch.cuda.get_device_name()) + print( + "GPU memory:", + torch.cuda.get_device_properties(0).total_memory / 1024**3, + "GB", + ) + else: + print("CUDA not available - running on CPU") + + # Define test configurations (B, T, V) + test_configs = [ + (1, 1024, 32000), + (1, 1024, 64000), + (1, 4096, 32000), + (1, 4096, 64000), + (1, 8192, 32000), + (1, 8192, 64000), + (1, 16384, 32000), + ] + + print(f"Testing configurations: {test_configs}") + + # Run benchmarks + print("\n" + "=" * 60) + pytorch_results = benchmark_loss_method("pytorch", test_configs) + + print("\n" + "=" * 60) + triton_results = benchmark_loss_method("triton", test_configs) + + # Print results summary + print(f"\n=== Performance Summary ===") + print(f"Configurations tested: {len(test_configs)}") + + # Print detailed results table + print( + f"\n{'Config (B,T,V)':<15} {'PyTorch (ms)':<15} {'Triton (ms)':<15} {'Speedup':<10} {'PyTorch Mem (GB)':<18} {'Triton Mem (GB)':<15} {'Memory Save':<12}" + ) + print("-" * 115) + + for i, config in enumerate(test_configs): + B, T, V = config + config_str = f"({B},{T},{V})" + + pytorch_result = next( + (r for r in pytorch_results if r["B"] == B and r["T"] == T and r["V"] == V), + None, + ) + triton_result = next( + (r for r in triton_results if r["B"] == B and r["T"] == T and r["V"] == V), + None, + ) + + if pytorch_result and triton_result: + pytorch_time_str = f"{pytorch_result['time_total']*1000:.2f}" + pytorch_mem_str = f"{pytorch_result['peak_memory']/1024**3:.2f}" + + triton_time_str = f"{triton_result['time_total']*1000:.2f}" + triton_mem_str = f"{triton_result['peak_memory']/1024**3:.2f}" + + if triton_result["time_total"] > 0: + speedup = pytorch_result["time_total"] / triton_result["time_total"] + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + + # Calculate memory savings percentage + if pytorch_result["peak_memory"] > 0: + memory_save_pct = ( + (pytorch_result["peak_memory"] - triton_result["peak_memory"]) + / pytorch_result["peak_memory"] + ) * 100 + memory_save_str = f"{memory_save_pct:.1f}%" + else: + memory_save_str = "N/A" + + print( + f"{config_str:<15} {pytorch_time_str:<15} {triton_time_str:<15} {speedup_str:<10} {pytorch_mem_str:<18} {triton_mem_str:<15} {memory_save_str:<12}" + ) + + +if __name__ == "__main__": + main() diff --git a/progress/SpecForge/specforge/core/__init__.py b/progress/SpecForge/specforge/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0ebc907b5d9a05954e1a69057c32cd070fd66e --- /dev/null +++ b/progress/SpecForge/specforge/core/__init__.py @@ -0,0 +1,9 @@ +from .dflash import OnlineDFlashModel, create_dflash_loss_mask +from .eagle3 import OnlineEagle3Model, QwenVLOnlineEagle3Model + +__all__ = [ + "OnlineDFlashModel", + "create_dflash_loss_mask", + "OnlineEagle3Model", + "QwenVLOnlineEagle3Model", +] diff --git a/progress/SpecForge/specforge/core/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d33560a7201585b6a7d155db5c52c200c084f907 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/dflash.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/dflash.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd4edc241aff3b353e31ade746d744835d3433a Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/dflash.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea211e5b42d94c8a18ca83ba7b338519a5cd5092 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-313.pyc b/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70ad9df8d9bc576b069956a22bb5278d6b249d61 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/dflash_lora.cpython-313.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/eagle3.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/eagle3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d6eca227261ddb9d0173ba667e5678616669fb4 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/eagle3.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/eagle3_adapters.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/eagle3_adapters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66066cefd0a949782bd81090074c1c9317c3aa85 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/eagle3_adapters.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/__pycache__/loss.cpython-311.pyc b/progress/SpecForge/specforge/core/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d6f6170c04a0c698cf4428e3e31fadf96479481 Binary files /dev/null and b/progress/SpecForge/specforge/core/__pycache__/loss.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/core/dflash.py b/progress/SpecForge/specforge/core/dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..79fa6a53c86bbded1cd77d73971da4d132456fad --- /dev/null +++ b/progress/SpecForge/specforge/core/dflash.py @@ -0,0 +1,509 @@ +# coding=utf-8 +"""DFlash Training Wrapper.""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint as grad_checkpoint + +from specforge.modeling.draft.dflash import DFlashDraftModel + +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + BlockMask = None + create_block_mask = None + + +class OnlineDFlashModel(nn.Module): + """DFlash online training wrapper with block-wise CE loss.""" + + def __init__( + self, + draft_model: DFlashDraftModel, + target_lm_head: nn.Module, + target_embed_tokens: nn.Module, + mask_token_id: int, + block_size: int = 16, + attention_backend: str = "flex_attention", + random_anchor: bool = False, + num_anchors: int = 512, + loss_decay_gamma: Optional[float] = None, + lm_head_chunk_size: int = 0, + ): + super().__init__() + self.draft_model = draft_model + self.lm_head = target_lm_head + self.embed_tokens = target_embed_tokens + self.block_size = block_size + self.mask_token_id = mask_token_id + self.attention_backend = attention_backend + self.random_anchor = random_anchor + self.num_anchors = num_anchors + self.loss_decay_gamma = loss_decay_gamma + self.lm_head_chunk_size = lm_head_chunk_size + + self._cached_block_mask: Optional[BlockMask] = None + self._cached_seq_len: Optional[int] = None + self._cached_bsz: Optional[int] = None + + def _sample_anchor_positions( + self, seq_len: int, loss_mask: torch.Tensor, device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Randomly sample anchor positions per sample; returns (anchors, keep_mask).""" + bs = self.block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(self.num_anchors, int(valid_counts.max().item())) + + if max_n == 0: + anchors = torch.arange(0, seq_len, bs, device=device) + anchors = anchors.unsqueeze(0).expand(bsz, -1) + return anchors, torch.ones( + bsz, anchors.shape[1], dtype=torch.bool, device=device + ) + + anchor_list = [] + keep_list = [] + for i in range(bsz): + valid_indices = valid[i].nonzero(as_tuple=False).squeeze(-1) + n_i = min(self.num_anchors, valid_indices.numel()) + if n_i == 0: + anchors_i = torch.zeros(max_n, dtype=torch.long, device=device) + keep_i = torch.zeros(max_n, dtype=torch.bool, device=device) + else: + perm = torch.randperm(valid_indices.numel(), device=device)[:n_i] + anchors_i = valid_indices[perm].sort().values + if n_i < max_n: + anchors_i = torch.cat( + [anchors_i, anchors_i[-1:].expand(max_n - n_i)], dim=0 + ) + keep_i = torch.zeros(max_n, dtype=torch.bool, device=device) + keep_i[:n_i] = True + anchor_list.append(anchors_i) + keep_list.append(keep_i) + return torch.stack(anchor_list, dim=0), torch.stack(keep_list, dim=0) + + def _build_blocks_from_anchors( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + loss_mask: torch.Tensor, + anchor_positions: torch.Tensor, + block_keep_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Gather fixed-size blocks; padding blocks get block_id=-1 and loss=0.""" + bs = self.block_size + device = input_ids.device + bsz = input_ids.shape[0] + n = anchor_positions.shape[1] + + offsets = torch.arange(bs, device=device).unsqueeze(0) + gather_idx = anchor_positions.unsqueeze(-1) + offsets + gather_idx = gather_idx.reshape(bsz, -1) + + block_input_ids = torch.gather(input_ids, 1, gather_idx) + block_hidden = torch.gather( + hidden_states, + 1, + gather_idx.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)), + ) + block_loss_mask = torch.gather(loss_mask, 1, gather_idx) + + token_keep = block_keep_mask.repeat_interleave(bs, dim=1) + block_loss_mask = block_loss_mask * token_keep.to(block_loss_mask.dtype) + + block_ids = torch.arange(n, device=device).repeat_interleave(bs) + pad_token_mask = (~block_keep_mask).repeat_interleave(bs, dim=1) + block_ids = block_ids.unsqueeze(0).expand(bsz, -1).clone() + block_ids[pad_token_mask] = -1 + + return block_input_ids, block_hidden, block_loss_mask, block_ids, gather_idx + + def prepare_noise_input( + self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Prepare noise input: first token of each block is real, rest are MASK.""" + bsz, seq_len = input_ids.shape + device = input_ids.device + + if block_ids is not None: + is_block_start = torch.ones(bsz, seq_len, dtype=torch.bool, device=device) + is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1] + else: + positions = torch.arange(seq_len, device=device) + is_block_start = (positions % self.block_size) == 0 + is_block_start = is_block_start.unsqueeze(0).expand(bsz, -1) + + noise_input_ids = torch.full_like(input_ids, self.mask_token_id) + noise_input_ids[is_block_start] = input_ids[is_block_start] + return noise_input_ids + + def _get_or_create_block_mask( + self, + bsz: int, + q_len: int, + kv_len: int, + device: torch.device, + block_ids: Optional[torch.Tensor] = None, + ) -> "BlockMask": + """Get cached BlockMask or create a new one.""" + if block_ids is None: + if ( + self._cached_block_mask is not None + and self._cached_seq_len == q_len + and self._cached_bsz == bsz + ): + return self._cached_block_mask + + block_size = self.block_size + + if block_ids is not None: + _block_ids = block_ids + + def dflash_mask_fn(b, h, q_idx, kv_idx): + L = q_len + is_ctx = kv_idx < L + q_b = _block_ids[b, q_idx] + k_ctx = _block_ids[b, kv_idx.clamp(max=L - 1)] + k_noise = _block_ids[b, (kv_idx - L).clamp(min=0, max=L - 1)] + q_valid = q_b >= 0 + k_ctx_valid = k_ctx >= 0 + k_noise_valid = k_noise >= 0 + ctx_visible = is_ctx & q_valid & k_ctx_valid & (k_ctx < q_b) + noise_visible = (~is_ctx) & q_valid & k_noise_valid & (k_noise == q_b) + return ctx_visible | noise_visible + + else: + + def dflash_mask_fn(b, h, q_idx, kv_idx): + L = q_len + is_ctx = kv_idx < L + q_block = q_idx // block_size + k_block_ctx = kv_idx // block_size + k_block_noise = (kv_idx - L) // block_size + ctx_visible = is_ctx & (k_block_ctx < q_block) + noise_visible = (~is_ctx) & (k_block_noise == q_block) + return ctx_visible | noise_visible + + block_mask = create_block_mask( + dflash_mask_fn, + B=bsz, + H=1, + Q_LEN=q_len, + KV_LEN=kv_len, + device=device, + ) + + if block_ids is None: + self._cached_block_mask = block_mask + self._cached_seq_len = q_len + self._cached_bsz = bsz + + return block_mask + + def _create_parallel_attention_mask( + self, + bsz: int, + seq_len: int, + device: torch.device, + block_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Create [bsz, L, 2L] attention mask for parallel training.""" + if block_ids is None: + ids = torch.arange(seq_len, device=device) // self.block_size + q_ids = ids.unsqueeze(1) + k_ids = ids.unsqueeze(0) + ctx_mask = k_ids < q_ids + noise_mask = q_ids == k_ids + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + return full_mask.unsqueeze(0).expand(bsz, -1, -1) + + q_ids = block_ids.unsqueeze(2) + k_ids = block_ids.unsqueeze(1) + q_valid = q_ids >= 0 + k_valid = k_ids >= 0 + ctx_mask = q_valid & k_valid & (k_ids < q_ids) + noise_mask = q_valid & k_valid & (k_ids == q_ids) + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=2) + full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + return full_mask + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Parallel block-wise training forward pass.""" + bsz, seq_len = input_ids.shape + device = input_ids.device + block_ids = None + + if self.random_anchor and self.training: + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + (input_ids, hidden_states, loss_mask, block_ids, block_positions) = ( + self._build_blocks_from_anchors( + input_ids, + hidden_states, + loss_mask, + anchor_positions, + block_keep_mask, + ) + ) + effective_len = input_ids.shape[1] + base_positions = block_positions + else: + n_blocks = seq_len // self.block_size + effective_len = n_blocks * self.block_size + input_ids = input_ids[:, :effective_len] + hidden_states = hidden_states[:, :effective_len, :] + loss_mask = loss_mask[:, :effective_len] + attention_mask = attention_mask[:, :effective_len] + base_positions = ( + torch.arange(effective_len, device=device).unsqueeze(0).expand(bsz, -1) + ) + + noise_input_ids = self.prepare_noise_input(input_ids, block_ids) + noise_embedding = self.embed_tokens(noise_input_ids) + + position_ids = torch.cat([base_positions, base_positions], dim=1) + + if ( + self.attention_backend == "flex_attention" + and FLEX_ATTENTION_AVAILABLE + and create_block_mask is not None + ): + dflash_attn_mask = self._get_or_create_block_mask( + bsz=bsz, + q_len=effective_len, + kv_len=effective_len * 2, + device=device, + block_ids=block_ids, + ) + else: + dflash_attn_mask = self._create_parallel_attention_mask( + bsz, effective_len, device, block_ids + ) + dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype) + dflash_attn_mask = dflash_attn_mask.unsqueeze(1) + + hidden = self.draft_model( + position_ids=position_ids, + noise_embedding=noise_embedding, + target_hidden=hidden_states, + attention_mask=dflash_attn_mask, + ) + + dflash_loss_weights = create_dflash_loss_mask( + effective_len, + self.block_size, + device, + gamma=self.loss_decay_gamma, + block_ids=block_ids, + ) + if block_ids is None: + dflash_loss_weights = dflash_loss_weights.unsqueeze(0) + combined_mask = loss_mask * dflash_loss_weights + + if self.lm_head_chunk_size > 0 and effective_len > self.lm_head_chunk_size: + loss, accuracy = self._chunked_lm_loss( + hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids + ) + else: + loss, accuracy = self._full_lm_loss( + hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids + ) + + return loss, accuracy + + def _compute_acceptance_accuracy( + self, + preds_all: torch.Tensor, + input_ids: torch.Tensor, + loss_mask: torch.Tensor, + effective_len: int, + block_ids: Optional[torch.Tensor], + ) -> torch.Tensor: + """Compute block-wise acceptance rate metric.""" + bsz = input_ids.shape[0] + correct_all = (preds_all == input_ids).float() + bs = self.block_size + n_blocks = effective_len // bs + + try: + if block_ids is not None: + correct_blocks = correct_all.reshape(bsz, n_blocks, bs) + loss_mask_blocks = loss_mask.reshape(bsz, n_blocks, bs) + else: + if n_blocks > 1: + correct_blocks = correct_all[:, bs:].reshape( + bsz, n_blocks - 1, bs + ) + loss_mask_blocks = loss_mask[:, bs:].reshape( + bsz, n_blocks - 1, bs + ) + else: + raise ValueError("Only one block") + + correct_pred = correct_blocks[:, :, 1:] + loss_mask_pred = loss_mask_blocks[:, :, 1:] + + block_valid = (loss_mask_pred.sum(dim=2) == (bs - 1)).float() + correct_pred = correct_pred * loss_mask_pred + cumulative_correct = correct_pred.cumprod(dim=2) + + acceptance_lengths = cumulative_correct.sum(dim=2) + acceptance_lengths = (acceptance_lengths * block_valid).sum(dim=1) + total_blocks_sum = block_valid.sum(dim=1).sum().clamp_min(1) + avg_accept_length = acceptance_lengths.sum() / total_blocks_sum + accuracy = avg_accept_length / (bs - 1) + except Exception: + valid_mask = (loss_mask > 0.5).reshape(-1) + correct_flat = correct_all.reshape(-1)[valid_mask] + accuracy = correct_flat.mean() if correct_flat.numel() > 0 else 0.0 + + return accuracy + + def _full_lm_loss( + self, + hidden: torch.Tensor, + input_ids: torch.Tensor, + loss_mask: torch.Tensor, + combined_mask: torch.Tensor, + effective_len: int, + block_ids: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Original non-chunked lm_head + loss computation.""" + logits = self.lm_head(hidden) + + with torch.no_grad(): + preds_all = logits.argmax(dim=-1) + accuracy = self._compute_acceptance_accuracy( + preds_all, input_ids, loss_mask, effective_len, block_ids + ) + + logits_flat = logits.reshape(-1, logits.size(-1)) + labels_flat = input_ids.reshape(-1) + mask_flat = combined_mask.reshape(-1) + + active_indices = mask_flat > 1e-6 + active_logits = logits_flat[active_indices] + active_labels = labels_flat[active_indices] + active_weights = mask_flat[active_indices] + + if self.loss_decay_gamma is not None: + per_token_loss = F.cross_entropy( + active_logits, active_labels, reduction="none" + ) + loss = (per_token_loss * active_weights).sum() / active_weights.sum() + else: + loss = F.cross_entropy(active_logits, active_labels) + + return loss, accuracy + + def _chunked_lm_loss( + self, + hidden: torch.Tensor, + input_ids: torch.Tensor, + loss_mask: torch.Tensor, + combined_mask: torch.Tensor, + effective_len: int, + block_ids: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Chunked lm_head + loss: avoids materializing full [bsz, seq, vocab] logits. + + Processes the sequence in chunks of lm_head_chunk_size. Each chunk uses + gradient checkpointing so logits are recomputed (not stored) during backward. + Peak logits memory: O(chunk_size * vocab_size) instead of O(seq_len * vocab_size). + """ + chunk_size = self.lm_head_chunk_size + + # 1. Accuracy: compute argmax per chunk (no_grad, no memory concern) + with torch.no_grad(): + preds_chunks = [] + for start in range(0, effective_len, chunk_size): + end = min(start + chunk_size, effective_len) + chunk_logits = self.lm_head(hidden[:, start:end, :]) + preds_chunks.append(chunk_logits.argmax(dim=-1)) + preds_all = torch.cat(preds_chunks, dim=1) + accuracy = self._compute_acceptance_accuracy( + preds_all, input_ids, loss_mask, effective_len, block_ids + ) + + # 2. Loss: chunked with gradient checkpointing + total_loss = torch.tensor(0.0, device=hidden.device) + total_weight = torch.tensor(0.0, device=hidden.device) + + def _chunk_ce(h_chunk, labels_chunk, weights_chunk): + logits_chunk = self.lm_head(h_chunk) + logits_flat = logits_chunk.reshape(-1, logits_chunk.size(-1)) + labels_flat = labels_chunk.reshape(-1) + weights_flat = weights_chunk.reshape(-1) + active = weights_flat > 1e-6 + if not active.any(): + return logits_flat.sum() * 0, weights_flat.sum() * 0 + per_token = F.cross_entropy( + logits_flat[active], labels_flat[active], reduction="none" + ) + return (per_token * weights_flat[active]).sum(), weights_flat[active].sum() + + for start in range(0, effective_len, chunk_size): + end = min(start + chunk_size, effective_len) + chunk_loss, chunk_weight = grad_checkpoint( + _chunk_ce, + hidden[:, start:end, :], + input_ids[:, start:end], + combined_mask[:, start:end], + use_reentrant=False, + ) + total_loss = total_loss + chunk_loss + total_weight = total_weight + chunk_weight + + loss = total_loss / total_weight.clamp_min(1e-8) + return loss, accuracy + + +def create_dflash_loss_mask( + seq_len: int, + block_size: int, + device: torch.device, + gamma: Optional[float] = None, + block_ids: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Create DFlash loss mask: excludes block starts; for non-random, also excludes block 0. + + Returns [seq_len] when block_ids is None, [bsz, seq_len] when block_ids is per-sample. + """ + positions = torch.arange(seq_len, device=device) + pos_in_block = positions % block_size + + if block_ids is not None: + is_block_start = torch.ones_like(block_ids, dtype=torch.bool) + is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1] + valid_mask = ~is_block_start & (block_ids >= 0) + pos_in_block = pos_in_block.unsqueeze(0) + else: + is_block_start = (positions % block_size) == 0 + is_first_block = (positions // block_size) == 0 + valid_mask = ~is_first_block & ~is_block_start + + if gamma is not None: + decay = torch.exp(-(pos_in_block.float() - 1.0) / gamma) + return valid_mask.float() * decay + else: + return valid_mask.float() diff --git a/progress/SpecForge/specforge/core/dflash_lora.py b/progress/SpecForge/specforge/core/dflash_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0b10e32c64a225d3f10f8cb8d073e907818646e3 --- /dev/null +++ b/progress/SpecForge/specforge/core/dflash_lora.py @@ -0,0 +1,670 @@ +"""OnlineDFlashLoRAModel: training wrapper for DFlash LoRA (Qwen3-8B + LoRA).""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint as grad_checkpoint + +from specforge.modeling.draft.dflash_lora import DFlashLoRADraftModel + +try: + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + BlockMask = None + create_block_mask = None + + +class OnlineDFlashLoRAModel(nn.Module): + """ + Training wrapper for DFlash LoRA. + + Unlike OnlineDFlashModel (which uses a separate small draft model + frozen target + hidden states), this wrapper uses the full Qwen3-8B with LoRA adapters directly. + No external hidden states are needed — the model uses its own representations. + + Training objective (1-step block diffusion): + Given: context tokens x_{ Tuple[torch.Tensor, torch.Tensor]: + """Randomly sample anchor positions per sample; returns (anchors, keep_mask).""" + bs = self.block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(self.num_anchors, int(valid_counts.max().item())) + + if max_n == 0: + anchors = torch.arange(0, seq_len, bs, device=device) + anchors = anchors.unsqueeze(0).expand(bsz, -1) + return anchors, torch.ones(bsz, anchors.shape[1], dtype=torch.bool, device=device) + + anchor_list = [] + keep_list = [] + for i in range(bsz): + valid_indices = valid[i].nonzero(as_tuple=False).squeeze(-1) + n_i = min(self.num_anchors, valid_indices.numel()) + if n_i == 0: + anchors_i = torch.zeros(max_n, dtype=torch.long, device=device) + keep_i = torch.zeros(max_n, dtype=torch.bool, device=device) + else: + perm = torch.randperm(valid_indices.numel(), device=device)[:n_i] + anchors_i = valid_indices[perm].sort().values + if n_i < max_n: + anchors_i = torch.cat( + [anchors_i, anchors_i[-1:].expand(max_n - n_i)], dim=0 + ) + keep_i = torch.zeros(max_n, dtype=torch.bool, device=device) + keep_i[:n_i] = True + anchor_list.append(anchors_i) + keep_list.append(keep_i) + return torch.stack(anchor_list, dim=0), torch.stack(keep_list, dim=0) + + def _build_blocks_from_anchors( + self, + input_ids: torch.Tensor, + loss_mask: torch.Tensor, + anchor_positions: torch.Tensor, + block_keep_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Gather fixed-size blocks from random anchor positions. + + Returns: + block_input_ids: [bsz, n*bs] + block_loss_mask: [bsz, n*bs] (padding blocks zeroed out) + block_ids: [bsz, n*bs] (-1 for padding blocks) + gather_idx: [bsz, n*bs] (original sequence positions) + """ + bs = self.block_size + device = input_ids.device + bsz = input_ids.shape[0] + n = anchor_positions.shape[1] + + offsets = torch.arange(bs, device=device).unsqueeze(0) + gather_idx = anchor_positions.unsqueeze(-1) + offsets + gather_idx = gather_idx.reshape(bsz, -1) + + block_input_ids = torch.gather(input_ids, 1, gather_idx) + block_loss_mask = torch.gather(loss_mask, 1, gather_idx) + + token_keep = block_keep_mask.repeat_interleave(bs, dim=1) + block_loss_mask = block_loss_mask * token_keep.to(block_loss_mask.dtype) + + block_ids = torch.arange(n, device=device).repeat_interleave(bs) + pad_token_mask = (~block_keep_mask).repeat_interleave(bs, dim=1) + block_ids = block_ids.unsqueeze(0).expand(bsz, -1).clone() + block_ids[pad_token_mask] = -1 + + return block_input_ids, block_loss_mask, block_ids, gather_idx + + def prepare_noise_input( + self, + input_ids: torch.Tensor, + context_len: int = 0, + block_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Replace block tokens with MASK, keeping only the anchor (first token of each block). + + Two modes: + - Fixed block (block_ids=None): anchors at context_len, context_len+bs, ... + - Random anchor (block_ids provided): anchors at block boundaries in block_ids + """ + bsz, seq_len = input_ids.shape + device = input_ids.device + + if block_ids is not None: + # Random anchor mode: first token of each block (block_id transition) is anchor + is_block_start = torch.ones(bsz, seq_len, dtype=torch.bool, device=device) + is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1] + noise_input_ids = torch.full_like(input_ids, self.mask_token_id) + anchor_mask = is_block_start & (block_ids >= 0) + noise_input_ids[anchor_mask] = input_ids[anchor_mask] + return noise_input_ids + + # Fixed block mode + noise_input_ids = input_ids.clone() + block_part = noise_input_ids[:, context_len:] + block_seq_len = block_part.shape[1] + + positions = torch.arange(block_seq_len, device=device) + is_block_start = (positions % self.block_size) == 0 + mask = ~is_block_start.unsqueeze(0).expand(bsz, -1) + block_part[mask] = self.mask_token_id + noise_input_ids[:, context_len:] = block_part + + return noise_input_ids + + # ------------------------------------------------------------------ + # Attention mask builders + # ------------------------------------------------------------------ + + def _create_block_mask_random_anchor( + self, + bsz: int, + seq_len: int, + block_ids: torch.Tensor, + device: torch.device, + ) -> "BlockMask": + """Create BlockMask for random anchor: within-block bidirectional attention. + + Not cached — block_ids change every training step. + """ + _block_ids = block_ids + + def dflash_lora_random_anchor_mask_fn(b, h, q_idx, kv_idx): + q_b = _block_ids[b, q_idx] + k_b = _block_ids[b, kv_idx] + q_valid = q_b >= 0 + k_valid = k_b >= 0 + same_block = q_b == k_b + return q_valid & k_valid & same_block + + return create_block_mask( + dflash_lora_random_anchor_mask_fn, + B=bsz, + H=1, + Q_LEN=seq_len, + KV_LEN=seq_len, + device=device, + ) + + def _build_additive_mask_random_anchor( + self, + block_ids: torch.Tensor, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Build [bsz, 1, seq_len, seq_len] additive mask for random anchor.""" + NEG_INF = torch.finfo(dtype).min + bsz, seq_len = block_ids.shape + q_ids = block_ids.unsqueeze(2) # [bsz, seq_len, 1] + k_ids = block_ids.unsqueeze(1) # [bsz, 1, seq_len] + q_valid = q_ids >= 0 + k_valid = k_ids >= 0 + visible = q_valid & k_valid & (q_ids == k_ids) + mask = torch.full((bsz, seq_len, seq_len), NEG_INF, device=device, dtype=dtype) + mask[visible] = 0.0 + return mask.unsqueeze(1) + + def _get_or_create_block_mask( + self, + bsz: int, + seq_len: int, + context_len: int, + device: torch.device, + ) -> "BlockMask": + """Get cached BlockMask or create a new one for LoRA full-sequence attention. + + Mask rules (same as build_dflash_full_attn_mask_fast, but zero extra memory): + - context token i → attends to j <= i (standard causal) + - block token i (in block b) → attends to all context + same block (bidirectional) + + Q_LEN == KV_LEN == seq_len (no context/noise KV concat like non-LoRA version). + """ + if ( + self._cached_block_mask is not None + and self._cached_seq_len == seq_len + and self._cached_context_len == context_len + and self._cached_bsz == bsz + ): + return self._cached_block_mask + + block_size = self.block_size + _context_len = context_len + + def dflash_lora_mask_fn(b, h, q_idx, kv_idx): + # Context query: standard causal + is_q_ctx = q_idx < _context_len + ctx_visible = is_q_ctx & (kv_idx <= q_idx) + + # Block query: attend to all context + same block (bidirectional) + is_q_block = q_idx >= _context_len + is_k_ctx = kv_idx < _context_len + q_block_id = (q_idx - _context_len) // block_size + k_block_id = (kv_idx - _context_len) // block_size + + block_attend_ctx = is_q_block & is_k_ctx + block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id) + block_visible = block_attend_ctx | block_attend_same + + return ctx_visible | block_visible + + block_mask = create_block_mask( + dflash_lora_mask_fn, + B=bsz, + H=1, + Q_LEN=seq_len, + KV_LEN=seq_len, + device=device, + ) + + self._cached_block_mask = block_mask + self._cached_seq_len = seq_len + self._cached_context_len = context_len + self._cached_bsz = bsz + + return block_mask + + def build_dflash_full_attn_mask_fast( + self, + seq_len: int, + context_len: int, + bsz: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Vectorized 4D additive attention mask [bsz, 1, seq_len, seq_len]. + + Rules: + - context token i → attends to j <= i (standard causal) + - block token i (in block b) → attends to: + * all context tokens (j < context_len) + * all tokens in block b (bidirectional within block) + + Returns additive mask: 0.0 = visible, -inf = masked. + """ + NEG_INF = torch.finfo(dtype).min + positions = torch.arange(seq_len, device=device) + + # Which block does each position belong to? context tokens get block_id = -1 + block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=device) + block_seq = positions[context_len:] + block_ids[context_len:] = block_seq // self.block_size # relative block index + + q_ids = block_ids.unsqueeze(1) # [seq_len, 1] + k_ids = block_ids.unsqueeze(0) # [1, seq_len] + q_pos = positions.unsqueeze(1) # [seq_len, 1] + k_pos = positions.unsqueeze(0) # [1, seq_len] + + is_q_context = q_ids < 0 # [seq_len, 1] + is_k_context = k_ids < 0 # [1, seq_len] + + # context query: causal + ctx_q_visible = is_q_context & (k_pos <= q_pos) + + # block query: attend to context + same block + block_q_attend_ctx = (~is_q_context) & is_k_context + block_q_attend_same = (~is_q_context) & (~is_k_context) & (q_ids == k_ids) + block_q_visible = block_q_attend_ctx | block_q_attend_same + + visible = ctx_q_visible | block_q_visible # [seq_len, seq_len] + + mask = torch.full((seq_len, seq_len), NEG_INF, device=device, dtype=dtype) + mask[visible] = 0.0 + + return mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1) + + def _compute_loss_weights( + self, + seq_len: int, + context_len: int, + device: torch.device, + ) -> torch.Tensor: + """ + Loss weight per position: 0 for context and block anchors, 1 (or decay) for block non-anchors. + """ + weights = torch.zeros(seq_len, device=device) + block_seq_len = seq_len - context_len + positions = torch.arange(block_seq_len, device=device) + pos_in_block = positions % self.block_size + is_anchor = pos_in_block == 0 + + if self.loss_decay_gamma is not None: + decay = torch.exp(-(pos_in_block.float() - 1.0) / self.loss_decay_gamma) + block_weights = (~is_anchor).float() * decay + else: + block_weights = (~is_anchor).float() + + weights[context_len:] = block_weights + return weights + + def _compute_loss_weights_random_anchor( + self, + block_ids: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Loss weights for random anchor mode: [bsz, n*bs]. + + 0 at anchor positions (first token of each block) and padding blocks, + 1 (or exponential decay) at other valid positions. + """ + bsz, seq_len = block_ids.shape + + # is_anchor: True at the first token of each block + is_anchor = torch.zeros(bsz, seq_len, dtype=torch.bool, device=device) + is_anchor[:, 0] = True + is_anchor[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1] + + valid = block_ids >= 0 # False for padding blocks + + if self.loss_decay_gamma is not None: + # pos_in_block is just position % block_size since blocks are contiguous + positions = torch.arange(seq_len, device=device) + pos_in_block = positions % self.block_size + decay = torch.exp(-(pos_in_block.float() - 1.0) / self.loss_decay_gamma) + return valid.float() * (~is_anchor).float() * decay.unsqueeze(0) + else: + return (valid & ~is_anchor).float() + + # ------------------------------------------------------------------ + # Loss computation + # ------------------------------------------------------------------ + + def _full_lm_loss( + self, + logits: torch.Tensor, + input_ids: torch.Tensor, + combined_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Original non-chunked logits + loss computation.""" + with torch.no_grad(): + preds = logits.argmax(dim=-1) + + logits_flat = logits.reshape(-1, logits.size(-1)) + labels_flat = input_ids.reshape(-1) + weights_flat = combined_mask.reshape(-1) + + active = weights_flat > 1e-6 + if not active.any(): + loss = logits_flat.sum() * 0.0 + elif self.loss_decay_gamma is not None: + per_token_loss = F.cross_entropy(logits_flat[active], labels_flat[active], reduction="none") + loss = (per_token_loss * weights_flat[active]).sum() / weights_flat[active].sum().clamp_min(1e-8) + else: + loss = F.cross_entropy(logits_flat[active], labels_flat[active]) + + return loss, preds + + def _chunked_lm_loss( + self, + hidden: torch.Tensor, + lm_head: nn.Module, + input_ids: torch.Tensor, + combined_mask: torch.Tensor, + effective_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Chunked lm_head + loss: avoids materializing full [bsz, seq, vocab] logits. + + Processes the sequence in chunks of lm_head_chunk_size. Each chunk uses + gradient checkpointing so logits are recomputed (not stored) during backward. + Peak logits memory: O(chunk_size * vocab_size) instead of O(seq_len * vocab_size). + """ + chunk_size = self.lm_head_chunk_size + + # 1. Accuracy: compute argmax per chunk (no_grad, no memory concern) + with torch.no_grad(): + preds_chunks = [] + for start in range(0, effective_len, chunk_size): + end = min(start + chunk_size, effective_len) + chunk_logits = lm_head(hidden[:, start:end, :]) + preds_chunks.append(chunk_logits.argmax(dim=-1)) + preds = torch.cat(preds_chunks, dim=1) + + # 2. Loss: chunked with gradient checkpointing + total_loss = torch.tensor(0.0, device=hidden.device) + total_weight = torch.tensor(0.0, device=hidden.device) + + def _chunk_ce(h_chunk, labels_chunk, weights_chunk): + logits_chunk = lm_head(h_chunk) + logits_flat = logits_chunk.reshape(-1, logits_chunk.size(-1)) + labels_flat = labels_chunk.reshape(-1) + weights_flat = weights_chunk.reshape(-1) + active = weights_flat > 1e-6 + if not active.any(): + return logits_flat.sum() * 0, weights_flat.sum() * 0 + per_token = F.cross_entropy( + logits_flat[active], labels_flat[active], reduction="none" + ) + return (per_token * weights_flat[active]).sum(), weights_flat[active].sum() + + for start in range(0, effective_len, chunk_size): + end = min(start + chunk_size, effective_len) + chunk_loss, chunk_weight = grad_checkpoint( + _chunk_ce, + hidden[:, start:end, :], + input_ids[:, start:end], + combined_mask[:, start:end], + use_reentrant=False, + ) + total_loss = total_loss + chunk_loss + total_weight = total_weight + chunk_weight + + loss = total_loss / total_weight.clamp_min(1e-8) + return loss, preds + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + context_len: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for DFlash LoRA training. + + Args: + input_ids: [bsz, seq_len] — original token ids + attention_mask: [bsz, seq_len] — padding mask (1=real, 0=pad) + loss_mask: [bsz, seq_len] — which positions to compute loss on (assistant tokens) + context_len: number of context tokens before blocks (0 = treat whole seq as blocks) + + Returns: + loss: scalar + accuracy: scalar (block-wise acceptance rate) + """ + bsz, seq_len = input_ids.shape + device = input_ids.device + + use_flex = ( + self.attention_backend == "flex_attention" + and FLEX_ATTENTION_AVAILABLE + and create_block_mask is not None + ) + + # ── Random anchor path ────────────────────────────────────────── + if self.random_anchor and self.training: + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + input_ids, loss_mask, block_ids, gather_idx = self._build_blocks_from_anchors( + input_ids, loss_mask, anchor_positions, block_keep_mask + ) + effective_len = input_ids.shape[1] + n_anchors = anchor_positions.shape[1] + + noise_input_ids = self.prepare_noise_input(input_ids, block_ids=block_ids) + position_ids = gather_idx # use original sequence positions for RoPE + + if use_flex: + dflash_mask = self._create_block_mask_random_anchor( + bsz, effective_len, block_ids, device + ) + else: + dflash_mask = self._build_additive_mask_random_anchor( + block_ids, device, torch.bfloat16 + ) + + dflash_weights = self._compute_loss_weights_random_anchor(block_ids, device) + combined_mask = loss_mask * dflash_weights + + use_chunked = self.lm_head_chunk_size > 0 and effective_len > self.lm_head_chunk_size + if use_chunked: + hidden = self.draft_model( + input_ids=noise_input_ids, + attention_mask=dflash_mask, + position_ids=position_ids, + output_hidden_states=True, + ) + lm_head = self.draft_model.get_lm_head() + loss, preds = self._chunked_lm_loss(hidden, lm_head, input_ids, combined_mask, effective_len) + else: + logits = self.draft_model( + input_ids=noise_input_ids, + attention_mask=dflash_mask, + position_ids=position_ids, + ) + loss, preds = self._full_lm_loss(logits, input_ids, combined_mask) + + with torch.no_grad(): + accuracy = self._compute_accuracy( + preds, input_ids, loss_mask, + context_len=0, effective_block_len=n_anchors * self.block_size, + ) + return loss, accuracy + + # ── Fixed block path ──────────────────────────────────────────── + # Align to block boundary + block_seq_len = seq_len - context_len + n_blocks = block_seq_len // self.block_size + effective_block_len = n_blocks * self.block_size + effective_len = context_len + effective_block_len + + input_ids = input_ids[:, :effective_len] + attention_mask = attention_mask[:, :effective_len] + loss_mask = loss_mask[:, :effective_len] + + # Prepare noisy input + noise_input_ids = self.prepare_noise_input(input_ids, context_len) + + if use_flex: + dflash_mask = self._get_or_create_block_mask( + bsz=bsz, + seq_len=effective_len, + context_len=context_len, + device=device, + ) + else: + dflash_mask = self.build_dflash_full_attn_mask_fast( + seq_len=effective_len, + context_len=context_len, + bsz=bsz, + device=device, + dtype=torch.bfloat16, + ) + + # Position ids + position_ids = torch.arange(effective_len, device=device).unsqueeze(0).expand(bsz, -1) + + # Loss weights: 0 for context + block anchors, 1 (or decay) for block non-anchors + dflash_weights = self._compute_loss_weights(effective_len, context_len, device) + combined_mask = loss_mask * dflash_weights.unsqueeze(0) + + # Forward + loss + use_chunked = self.lm_head_chunk_size > 0 and effective_len > self.lm_head_chunk_size + + if use_chunked: + # Get hidden states, then apply lm_head in chunks + hidden = self.draft_model( + input_ids=noise_input_ids, + attention_mask=dflash_mask, + position_ids=position_ids, + output_hidden_states=True, + ) + lm_head = self.draft_model.get_lm_head() + loss, preds = self._chunked_lm_loss( + hidden, lm_head, input_ids, combined_mask, effective_len, + ) + else: + # Standard path: get logits directly + logits = self.draft_model( + input_ids=noise_input_ids, + attention_mask=dflash_mask, + position_ids=position_ids, + ) + loss, preds = self._full_lm_loss(logits, input_ids, combined_mask) + + # Accuracy (no grad) + with torch.no_grad(): + accuracy = self._compute_accuracy(preds, input_ids, loss_mask, context_len, effective_block_len) + + return loss, accuracy + + def _compute_accuracy( + self, + preds: torch.Tensor, + input_ids: torch.Tensor, + loss_mask: torch.Tensor, + context_len: int, + effective_block_len: int, + ) -> torch.Tensor: + """Block-wise acceptance rate: fraction of non-anchor block positions predicted correctly.""" + bsz = input_ids.shape[0] + bs = self.block_size + n_blocks = effective_block_len // bs + + if n_blocks == 0: + return torch.tensor(0.0, device=input_ids.device) + + block_preds = preds[:, context_len:context_len + effective_block_len] + block_labels = input_ids[:, context_len:context_len + effective_block_len] + block_loss_mask = loss_mask[:, context_len:context_len + effective_block_len] + + correct = (block_preds == block_labels).float() + + try: + correct_blocks = correct.reshape(bsz, n_blocks, bs) + mask_blocks = block_loss_mask.reshape(bsz, n_blocks, bs) + + # Only non-anchor positions (pos 1..bs-1) + correct_pred = correct_blocks[:, :, 1:] + mask_pred = mask_blocks[:, :, 1:] + + block_valid = (mask_pred.sum(dim=2) == (bs - 1)).float() + cumulative = (correct_pred * mask_pred).cumprod(dim=2) + accept_len = cumulative.sum(dim=2) + accept_len = (accept_len * block_valid).sum(dim=1) + total = block_valid.sum().clamp_min(1) + accuracy = accept_len.sum() / total / (bs - 1) + except Exception: + valid = (loss_mask > 0.5).reshape(-1) + correct_flat = correct.reshape(-1) + accuracy = correct_flat[valid].mean() if valid.any() else torch.tensor(0.0, device=input_ids.device) + + return accuracy diff --git a/progress/SpecForge/specforge/core/eagle3.py b/progress/SpecForge/specforge/core/eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..1e2f04e7ea4426e1390dbad0bc90464bd6d84239 --- /dev/null +++ b/progress/SpecForge/specforge/core/eagle3.py @@ -0,0 +1,606 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import DynamicCache + +from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter +from specforge.core.loss import LogSoftmaxLoss +from specforge.modeling.draft import Eagle3DraftModel +from specforge.utils import padding + + +class Eagle3Model(nn.Module): + pass + + +class OnlineEagle3Model(Eagle3Model): + """ + In sgl-spec, we implement offline/online training. + Online training means we have the target hidden_states available during training. + Eagle3 using test time training technique (TTT) to train the draft model. + 1. We first extract the hidden states from the target model. + 2. Then concatenate the hidden states from 3 aux layers (layer 1, layer num_layers//2, layer num_layers-4). + 3. We project the concatenated hidden states to the target hidden size. from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size) + 4. We concat the projected hidden states and embedding output as the input for the draft model. + 5. finally, we run TTT to train the draft model. input size is (batch, seq_len, hidden_size * 2) + """ + + def __init__( + self, + draft_model: Eagle3DraftModel, + length: int = 7, + attention_backend="sdpa", + target_model: Optional[Eagle3Model] = None, + ): + """ + Args: + target_model: the target model to extract hidden states. + draft_model: the draft model to be trained. + length: TTT length, it means how many turns to unroll during TTT. + """ + super().__init__() + self.draft_model = draft_model + self.length = length + self.attention_backend = attention_backend + self.target_model = target_model + + def _make_adapter(self) -> BackendAdapter: + if self.attention_backend == "usp": + return UspAdapter(self) + return SdpaLikeAdapter(self) + + def _acc_and_loss( + self, + *, + logits: torch.Tensor, + target_p: torch.Tensor, + position_mask: torch.Tensor, + loss_mask: torch.Tensor, + adapter: BackendAdapter, + ) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + local_correct = ( + (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) + ).sum() + local_denom = loss_mask.sum().clamp_min(1e-6) + local_correct, local_denom = adapter.reduce_metrics( + local_correct=local_correct, local_denom=local_denom + ) + acc = local_correct / local_denom + + loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + loss = adapter.reduce_loss(loss) + return acc, loss + + def _prepare_position_ids( + self, + position_ids: Optional[torch.Tensor], + *, + seq_length: int, + past_key_values_length: int, + device: torch.device, + is_vlm: bool, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.attention_backend == "usp": + return position_ids + if position_ids is None: + if is_vlm: + mrope_positions_ids, _ = self.target_model.get_rope_index( + input_ids=input_ids, image_grid_thw=image_grid_thw + ) + return mrope_positions_ids + return ( + torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + .unsqueeze(0) + .view(-1, seq_length) + ) + + position_ids = position_ids.long() + return position_ids.view(-1, seq_length) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, + **kwargs, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711 + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + past_key_values: We dont use this past_key_values in eagle3, but keep it for compatibility. We control kvcache by cache_hidden. + position_ids: (batch, seq_len) + """ + # Step 1: handle vocab size + target_p_padded, position_mask = _compute_target_p_padded( + target=target, + t2d=self.draft_model.t2d, + loss_mask=loss_mask, + length=self.length, + ) + del target + torch.cuda.empty_cache() + + # basic info + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + # Step 2: project the concatenated hidden states to the target hidden size + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Step 3: process kv cache, position ids and position ids + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + position_ids = self._prepare_position_ids( + position_ids=position_ids, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + device=hidden_states.device, + is_vlm=is_vlm, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + ) + + # Step 4: handle attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + if self.attention_backend == "sdpa": + attention_mask = self.draft_model.prepare_decoder_attention_mask( + attention_mask=attention_mask, + hidden_states=hidden_states, + batch_size=batch_size, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + ) + + # Step 5: run TTT + plosses = [] + vlosses = [] + acces = [] + adapter = self._make_adapter() + # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift + global_input_ids = input_ids + if self.attention_backend in ["sdpa", "fa", "usp"]: + cache_hidden = [[], []] + past_key_values = None + elif self.attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + else: + raise ValueError(f"Unknown attention backend: {self.attention_backend}") + + for idx in range(self.length): + state = adapter.step_view( + idx=idx, + ttt_length=self.length, + global_input_ids=global_input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + position_ids=position_ids, + hidden_states=hidden_states, + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=seq_length, + ) + is_last = idx == self.length - 1 + + # Step 5.1: embed the input ids + inputs_embeds = self.draft_model.embed_input_ids(state.input_ids) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + + # Step 5.2: run the draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=state.hidden_states, + cache_hidden=cache_hidden, + attention_mask=state.attention_mask, + position_ids=state.position_ids, + past_key_values=past_key_values, + use_cache=True, + ) + + # update hidden states for next step + hidden_states = hidden_states_out + + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) + + # Step 5.5 + 5.6: metric and loss + acc, loss = self._acc_and_loss( + logits=logits, + target_p=state.target_p, + position_mask=state.position_mask, + loss_mask=state.loss_mask, + adapter=adapter, + ) + acces.append(acc) + plosses.append(loss) + + if not is_last: + # Step 5.7: we need to update the loss mask + global_input_ids = padding(global_input_ids, left=False) + position_mask = padding(position_mask, left=False) + loss_mask = padding(loss_mask, left=False) + # Flex attention mask shirnking is handled inside attention module + return plosses, vlosses, acces + + +class QwenVLOnlineEagle3Model(Eagle3Model): + """ + In sgl-spec, we implement offline/online training. + Online training means we have the target hidden_states available during training. + Eagle3 using test time training technique (TTT) to train the draft model. + 1. We first extract the hidden states from the target model. + 2. Then concatenate the hidden states from 3 aux layers (layer 1, layer num_layers//2, layer num_layers-4). + 3. We project the concatenated hidden states to the target hidden size. from (batch, seq_len, 3*hidden_size) to (batch, seq_len, hidden_size) + 4. We concat the projected hidden states and embedding output as the input for the draft model. + 5. finally, we run TTT to train the draft model. input size is (batch, seq_len, hidden_size * 2) + """ + + def __init__( + self, + target_model, + draft_model: Eagle3DraftModel, + processor, + length: int = 7, + attention_backend: str = "sdpa", + ): + """ + Args: + target_model: the target model to extract hidden states. + draft_model: the draft model to be trained. + length: TTT length, it means how many turns to unroll during TTT. + """ + super().__init__() + self.target_model = target_model + self.draft_model = draft_model + self.processor = processor + self.length = length + self.attention_backend = attention_backend + + @torch.no_grad() + def _prepare_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L692 + Extract the hidden states from the target model outputs. + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + device: the device to run the target model, if None, use the input_ids device + pixel_values: image pixel values, used for VLM models + image_grid_thw: image grid thw, used for VLM models + + Returns: + hidden_states: (batch, seq_len, 3*hidden_size) + target: (batch, seq_len, vocab_size) + loss_mask: (batch, seq_len) + input_ids: (batch, seq_len) + """ + + if device is None: + device = input_ids.device + + # run the target model to get the hidden states + outputs = self.target_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + use_cache=False, + ) + + # extract the aux hidden states + # output_hidden_states = True will return the embedding output as well + # so we have an offset of 1 + num_hidden_states = len(outputs.hidden_states) + offset = 1 + num_layers = num_hidden_states - 1 + + # Eagle3 uses 3 aux layers from layer 1, num_layers//2, num_layers-4 + low_aux_layer = 1 + offset + mid_aux_layer = num_layers // 2 - 1 + offset + last_aux_layer = num_layers - 4 + offset + + hidden_states0 = outputs.hidden_states[low_aux_layer] + hidden_states1 = outputs.hidden_states[mid_aux_layer] + hidden_states2 = outputs.hidden_states[last_aux_layer] + + hidden_states = torch.cat( + (hidden_states0, hidden_states1, hidden_states2), dim=-1 + ) + + # apply pading + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + + if target is not None: + target = target.to(device) + loss_mask = loss_mask[..., None] + loss_mask = loss_mask.to(device) + + return hidden_states, target, loss_mask, input_ids + + @torch.no_grad() + def _get_input_embeds( + self, + input_ids: torch.Tensor, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + ) -> torch.Tensor: + # get input embeding with image + # inputs_embeds = self.target_model.model.get_input_embeddings()(input_ids) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + image_embeds = self.target_model.model.get_image_features( + pixel_values, image_grid_thw + ) + image_embeds = torch.cat(image_embeds, dim=0) + n_image_tokens = ( + input_ids == self.target_model.model.config.image_token_id + ).sum() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.target_model.model.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Online eagle model trainer, modified from: https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py#L711 + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + loss_mask: (batch, seq_len) + past_key_values: We dont use this past_key_values in eagle3, but keep it for compatibility. We control kvcache by cache_hidden. + position_ids: (batch, seq_len) + pixel_values: batch image pixel values, used for VLM models + image_grid_thw: (batch, 3), image grid thw, used for VLM models + """ + # Step 0: prepare data with the target model + hidden_states, target, loss_mask, input_ids = self._prepare_data( + input_ids, attention_mask, loss_mask, pixel_values, image_grid_thw + ) + + # Step 1: handle vocab size + target_p_padded, position_mask = _compute_target_p_padded( + target=target, + t2d=self.draft_model.t2d, + loss_mask=loss_mask, + length=self.length, + ) + del target + + # basic info + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + # Step 2: project the concatenated hidden states to the target hidden size + hidden_states = self.draft_model.project_hidden_states(hidden_states) + + # Step 3: process kv cache, position ids and position ids + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + attention_mask_tensor = ( + attention_mask + if not isinstance(attention_mask, dict) + else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal( + attention_mask_tensor[:, 0], dim1=1, dim2=2 + ) + attention_mask_tensor = ( + attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + ) + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + position_ids, rope_deltas = self.target_model.model.get_rope_index( + input_ids, + image_grid_thw, + None, + second_per_grid_ts=None, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + else: + position_ids = position_ids + + # Step 4: handle attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + if self.attention_backend == "sdpa": + attention_mask = self.draft_model.prepare_decoder_attention_mask( + attention_mask=attention_mask, + hidden_states=hidden_states, + batch_size=batch_size, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + ) + + # Step 5: run TTT + plosses = [] + vlosses = [] + acces = [] + if self.attention_backend in ["sdpa", "fa"]: + cache_hidden = [[], []] + past_key_values = None + elif self.attention_backend == "flex_attention": + cache_hidden = None + past_key_values = DynamicCache() + else: + raise ValueError(f"Unknown attention backend: {self.attention_backend}") + + for idx in range(self.length): + target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() + is_last = idx == self.length - 1 + + # Step 5.1: embed the input ids + # inputs_embeds = self._get_input_embeds(input_ids, pixel_values, image_grid_thw) + inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = inputs_embeds.to(hidden_states.dtype) + + # Step 5.2: run the draft model backbone + hidden_states_out = self.draft_model.backbone( + input_embeds=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + ) + + # update hidden states for next step + hidden_states = hidden_states_out + + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) + + # Step 5.5: record metrics first as we in-place modify logits + with torch.no_grad(): + acces.append( + _compute_metric_acc( + logits=logits, + target_p=target_p, + position_mask=position_mask, + loss_mask=loss_mask, + ) + ) + + # Step 5.6: calculate loss, in-place modifies logits! + loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + plosses.append(loss) + + if not is_last: + # Step 5.7: we need to update the loss mask + input_ids = padding(input_ids, left=False) + position_mask = padding(position_mask, left=False) + loss_mask = padding(loss_mask, left=False) + # Flex attention mask shirnking is handled inside attention module + return plosses, vlosses, acces + + +def _compute_target_p_padded(target, t2d, loss_mask, length): + with torch.no_grad(): + target_p, position_mask = _compute_target_p( + target=target, + t2d=t2d, + loss_mask=loss_mask, + ) + + assert len(target_p.shape) == 3 + target_p_padded = F.pad( + target_p, + pad=(0, 0, 0, length), + mode="constant", + # For bitwise equality with previous code + value=1 / target_p.shape[-1], + ) + + return target_p_padded, position_mask + + +@torch.compile(dynamic=None) +def _compute_target_p(target, t2d, loss_mask): + target_head = target + target_max_token = target_head.argmax(-1) + target_mask = t2d[target_max_token] + target_mask = target_mask[..., None].int() + position_mask = target_mask * loss_mask + target_head = target_head[..., t2d] + target_head = target_head.float() + target_p = nn.Softmax(dim=2)(target_head) + target_p = target_p.detach() + return target_p, position_mask + + +@torch.compile(dynamic=None) +def _compute_metric_acc(logits, target_p, position_mask, loss_mask): + return ( + (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) + ).sum() / loss_mask.sum().clamp_min(1e-6) diff --git a/progress/SpecForge/specforge/core/eagle3_adapters.py b/progress/SpecForge/specforge/core/eagle3_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..555c16efc3b6b894caaa4a560456533a87f3be2e --- /dev/null +++ b/progress/SpecForge/specforge/core/eagle3_adapters.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.distributed.nn.functional as dist_nn + +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group + + +@dataclass +class StepState: + input_ids: torch.Tensor + hidden_states: torch.Tensor + position_ids: torch.Tensor + attention_mask: torch.Tensor + target_p: torch.Tensor + position_mask: torch.Tensor + loss_mask: torch.Tensor + + +class BackendAdapter: + def __init__(self, model: "OnlineEagle3Model"): + self.m = model + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + raise NotImplementedError + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + return loss + + +class SdpaLikeAdapter(BackendAdapter): + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() + return StepState( + input_ids=global_input_ids, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + target_p=target_p, + position_mask=position_mask, + loss_mask=loss_mask, + ) + + +class UspAdapter(BackendAdapter): + def __init__(self, model: "OnlineEagle3Model"): + super().__init__(model) + self.sp_group = get_draft_sp_group() + self.sp_world_size = dist.get_world_size(self.sp_group) + self.ulysses_pg = get_sp_ulysses_group() + self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg) + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + usp_chunk_size = seq_length - ttt_length + if usp_chunk_size <= 0: + raise ValueError( + f"USP local seq_length ({seq_length}) must be larger than " + f"ttt_length ({ttt_length})" + ) + target_p = target_p_padded[:, idx : idx + usp_chunk_size, :] + return StepState( + input_ids=global_input_ids[:, :usp_chunk_size], + hidden_states=hidden_states[:, :usp_chunk_size, :], + position_ids=position_ids[:, : usp_chunk_size * self.sp_ulysses_degree], + attention_mask=attention_mask[:, :usp_chunk_size], + target_p=target_p, + position_mask=position_mask[:, :usp_chunk_size, :], + loss_mask=loss_mask[:, :usp_chunk_size, :], + ) + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + local_correct = dist_nn.all_reduce( + local_correct, op=dist.ReduceOp.SUM, group=self.sp_group + ) + local_denom = dist_nn.all_reduce( + local_denom, op=dist.ReduceOp.SUM, group=self.sp_group + ) + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + loss = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.sp_group) + loss = loss / self.sp_world_size + return loss diff --git a/progress/SpecForge/specforge/core/loss.py b/progress/SpecForge/specforge/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..30e7fba7dd49cce3706ec9c740d9e597f0eade47 --- /dev/null +++ b/progress/SpecForge/specforge/core/loss.py @@ -0,0 +1,244 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. +The idea of in-place backward pass is from Liger-Kernel. +See the original Liger-Kernel repository at https://github.com/linkedin/Liger-Kernel. +""" + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# Reference implementation +@torch.compile(dynamic=None) +def _compute_loss(logits, target_p, position_mask): + logits = logits.float() + out_logp = nn.LogSoftmax(dim=2)(logits) + plogp = target_p * out_logp + loss = -torch.sum(position_mask * plogp, 2).mean() + return loss + + +def _calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 131072 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." + ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + # AMD GPU (ROCm) + if hasattr(torch.version, "hip") and torch.version.hip is not None: + num_warps //= 2 + + return BLOCK_SIZE, num_warps + + +@triton.jit +def log_softmax_forward_kernel( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + position_mask_stride, + loss_ptr, + loss_stride, + m_ptr, + d_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id * position_mask_stride + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + return + + m = float("-inf") + d = 0.0 + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load( + logits_ptr + offsets, mask=mask, other=float("-inf") + ).cast(tl.float32) + block_max = tl.max(tl.where(mask, logits_block, float("-inf"))) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum( + tl.where(mask, tl.exp(logits_block - m_new), 0.0) + ) + m = m_new + + loss = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + # log-softmax: log(exp(x - max) / sum) = (x - max) - log(sum) + normalized_logits = logits_block - m + log_normalizer = tl.log(d) + log_softmax_logits = normalized_logits - log_normalizer + weighted_log_prob = target_block * log_softmax_logits + loss += tl.sum(tl.where(mask, weighted_log_prob, 0.0)) + + loss_ptr += program_id * loss_stride + m_ptr += program_id + d_ptr += program_id + tl.store(loss_ptr, -loss) + tl.store(m_ptr, m.to(tl.float32)) + tl.store(d_ptr, d.to(tl.float32)) + + +@triton.jit +def log_softmax_backward_kernel( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + grad_output_ptr, + scaling_factor, + m_ptr, + d_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id + + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + tl.store(logits_ptr + offsets, 0.0, mask=mask) + return + + m_ptr += program_id + d_ptr += program_id + m = tl.load(m_ptr).to(tl.float32) + d = tl.load(d_ptr).to(tl.float32) + grad_output = tl.load(grad_output_ptr).to(tl.float32) + grad_output = grad_output * scaling_factor + + # First pass: compute sum of (target * grad_output) + target_grad_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_grad_sum += tl.sum(tl.where(mask, target_block * grad_output, 0.0)) + + # Second pass: compute log-softmax gradients + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + softmax_prob = tl.exp(logits_block - m) / d + normalized_grad = softmax_prob * target_grad_sum + grad_block = -(target_block * grad_output - normalized_grad) + tl.store(logits_ptr + offsets, grad_block.to(tl.float32), mask=mask) + + +class LogSoftmaxLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, target, position_mask): + B, T, V = logits.shape + loss = torch.zeros((B * T, 1), device=logits.device) + logits_flat = logits.contiguous().view(B * T, V) + target_flat = target.contiguous().view(B * T, V) + position_mask_flat = position_mask.contiguous().view(B * T, 1).bool() + grid = (B * T,) + m = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + d = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + BLOCK_SIZE, num_warps = _calculate_settings(V) + log_softmax_forward_kernel[grid]( + logits_flat, + logits_flat.stride(0), + target_flat, + target_flat.stride(0), + position_mask_flat, + position_mask_flat.stride(0), + loss, + loss.stride(0), + m, + d, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + ctx.save_for_backward(logits.detach(), target, position_mask, m, d) + return loss.squeeze(1).mean() + + @staticmethod + def backward(ctx, grad_output): + logits, target, position_mask, m, d = ctx.saved_tensors + B, T, V = logits.shape + scaling_factor = 1.0 / (B * T) + logits = logits.contiguous().view(B * T, V) + target = target.contiguous().view(B * T, V) + position_mask = position_mask.contiguous().view(B * T, 1).bool() + grid = (B * T,) + BLOCK_SIZE, num_warps = _calculate_settings(V) + log_softmax_backward_kernel[grid]( + logits, + logits.stride(0), + target, + target.stride(0), + position_mask, + grad_output, + scaling_factor, + m, + d, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + logits = logits.view(B, T, V) + return logits, None, None, None, None + + +if __name__ == "__main__": + device = "cuda" + B, T, V = 1, 1024, 16000 + logits = torch.randn(B, T, V, device=device, requires_grad=True) + logits2 = logits.clone().detach().requires_grad_(True) + target = torch.randn(B, T, V, device=device) + position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device) + position_mask = torch.ones((B, T, 1), dtype=torch.bool, device=device) + output1 = LogSoftmaxLoss.apply(logits, target, position_mask) + output2 = _compute_loss(logits2, target, position_mask) + torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4) + output1.backward() + output2.backward() + torch.testing.assert_close(logits.grad, logits2.grad, rtol=1e-4, atol=1e-4) diff --git a/progress/SpecForge/specforge/data/__init__.py b/progress/SpecForge/specforge/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1385db853aeca011f86d4fbb85c4d2cb426b73e --- /dev/null +++ b/progress/SpecForge/specforge/data/__init__.py @@ -0,0 +1,13 @@ +from .preprocessing import ( + build_eagle3_dataset, + build_offline_eagle3_dataset, + generate_vocab_mapping_file, +) +from .utils import prepare_dp_dataloaders + +__all__ = [ + "build_eagle3_dataset", + "build_offline_eagle3_dataset", + "generate_vocab_mapping_file", + "prepare_dp_dataloaders", +] diff --git a/progress/SpecForge/specforge/data/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e87559810ec3be16a25e4eced2826fb92652038 Binary files /dev/null and b/progress/SpecForge/specforge/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/data/__pycache__/parse.cpython-311.pyc b/progress/SpecForge/specforge/data/__pycache__/parse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d355e338a7307d6684b67a4143afa36fe9495db Binary files /dev/null and b/progress/SpecForge/specforge/data/__pycache__/parse.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/data/__pycache__/preprocessing.cpython-311.pyc b/progress/SpecForge/specforge/data/__pycache__/preprocessing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ade097819faf5b3aafa2ad0fef81b229c15add Binary files /dev/null and b/progress/SpecForge/specforge/data/__pycache__/preprocessing.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/data/__pycache__/template.cpython-311.pyc b/progress/SpecForge/specforge/data/__pycache__/template.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24d49ce17a36547f012499674bf311a5536da02e Binary files /dev/null and b/progress/SpecForge/specforge/data/__pycache__/template.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/data/__pycache__/utils.cpython-311.pyc b/progress/SpecForge/specforge/data/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..906ed30df377b030144477d75e60838a2850c0c8 Binary files /dev/null and b/progress/SpecForge/specforge/data/__pycache__/utils.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/data/parse.py b/progress/SpecForge/specforge/data/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..073e882a33016d3a76af9999809a19642fa0d6b0 --- /dev/null +++ b/progress/SpecForge/specforge/data/parse.py @@ -0,0 +1,341 @@ +import json +import re +import warnings +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + +import torch +from transformers import PreTrainedTokenizer + +from .template import ChatTemplate + +__all__ = ["GeneralParser", "HarmonyParser"] + + +class Parser(ABC): + + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + self.tokenizer = tokenizer + self.chat_template = chat_template + + @abstractmethod + def parse( + self, conversation: "Conversation", max_length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parse the conversation into a list of tensors. + + Args: + conversation: The conversation to parse. + + Returns: + A list of tensors: [input_ids, loss_mask] + """ + + +_harmony_encoding = None + + +class GeneralParser(Parser): + + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + self.system_prompt = chat_template.system_prompt + self.user_message_separator = f"{chat_template.end_of_turn_token}" + self.assistant_message_separator = f"{chat_template.assistant_header}" + self.set_assistant_pattern(chat_template) + + def apply_chat_template(self, messages, **kwargs) -> str: + conversation = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False, **kwargs + ) + return conversation + + def set_assistant_pattern(self, chat_template: ChatTemplate): + if chat_template.assistant_pattern_type == "longcat": + self.assistant_pattern = ( + re.escape(self.assistant_message_separator) + + r"([\s\S]*?(?:" + + re.escape("[Round ") + + r"\d+" + + re.escape("] USER:") + + "|$))" + ) + else: + self.assistant_pattern = ( + re.escape(self.assistant_message_separator) + + r"([\s\S]*?(?:" + + re.escape(self.chat_template.end_of_turn_token) + + "|$))" + ) + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, + ) -> Dict[str, List[torch.Tensor]]: + if not preformatted: + messages = [] + + if conversation[0]["role"] == "system": + warnings.warn( + f"The first message is from system, we will use the system prompt from the data and ignore the system prompt from the template" + ) + messages.append( + {"role": "system", "content": conversation[0]["content"]} + ) + conversation = conversation[1:] + else: + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + for j, sentence in enumerate(conversation): + role = sentence["role"] + if j == 0: + if role != "user": + warnings.warn( + f"Conversation must start with a 'user' role, but found '{role}'. Conversation truncated." + ) + break + else: + prev_role = conversation[j - 1]["role"] + if role == "tool" and prev_role not in ["assistant", "tool"]: + warnings.warn( + f"A 'tool' message must follow an 'assistant' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." + ) + break + if role == "assistant" and prev_role not in ["user", "tool"]: + warnings.warn( + f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." + ) + break + tool_calls = sentence.get("tool_calls") + if isinstance(tool_calls, str): + try: + sentence["tool_calls"] = json.loads(tool_calls) + except json.JSONDecodeError: + warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}") + break + messages.append(sentence) + + try: + conversation = self.apply_chat_template(messages, **kwargs) + except (ValueError, TypeError): + # Fallback rendering for tokenizers without built-in chat_template + warnings.warn( + "Tokenizer does not have a chat_template, using fallback rendering." + ) + parts = [] + bos_token = getattr(self.tokenizer, "bos_token", None) + user_header = self.chat_template.user_header or "" + assistant_header = self.chat_template.assistant_header or "" + end_of_turn = self.chat_template.end_of_turn_token or "" + + # Add BOS token at the start + if bos_token: + parts.append(bos_token) + + for msg in messages: + if msg["role"] == "system": + parts.append(msg["content"]) + elif msg["role"] == "user": + parts.append(f"{user_header}{msg['content']}") + elif msg["role"] == "assistant": + parts.append(f"{assistant_header}{msg['content']}{end_of_turn}") + conversation = "".join(parts) + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.unk_token_id + + # get input_ids + encoding = self.tokenizer( + conversation, + max_length=max_length, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + loss_mask = torch.zeros(len(input_ids), dtype=torch.long) + + matches = list(re.finditer(self.assistant_pattern, conversation, re.DOTALL)) + if train_only_last_turn and matches: + matches = [matches[-1]] # Only keep the last match + + for match in matches: + content_start_char = match.start(1) + content_end_char = match.end(1) + + # --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length --- + # Encode the text "assistant start", the length of which is the position of the starting token. + prefix_ids = self.tokenizer.encode( + conversation[:content_start_char], + add_special_tokens=False, + truncation=True, + max_length=max_length, + ) + # Encodes the text "assistant end", the length of which is the position of the end token. + full_ids = self.tokenizer.encode( + conversation[:content_end_char], + add_special_tokens=False, + truncation=True, + max_length=max_length, + ) + + start_token_idx = len(prefix_ids) + end_token_idx = len(full_ids) + + # Handling out-of-bounds errors caused by truncation + actual_start = min(start_token_idx, len(input_ids)) + actual_end = min(end_token_idx, len(input_ids)) + + if actual_start < actual_end: + loss_mask[actual_start:actual_end] = 1 + return input_ids, loss_mask + + +class HarmonyParser(Parser): + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + self.reasoning_levels = ["low", "medium", "high"] + self.default_reasoning_level = "low" + + def build_single_turn_prompt( + self, + prompt_text: str, + role: str, + content: str, + ) -> str: + """Embed user message into the required prompt template.""" + if role == "system": + prompt_text = f"<|start|>system<|message|>{content}<|end|>" + elif role == "assistant_reasoning_effort": + prompt_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-06-28\n\nReasoning: {content.lower()}\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>" + elif role == "user": + prompt_text += f"<|start|>user<|message|>{content}<|end|>" + elif role == "assistant_analysis": + prompt_text += ( + f"<|start|>assistant<|channel|>analysis<|message|>{content}<|end|>" + ) + elif role == "assistant_commentary": + prompt_text += ( + f"<|start|>assistant<|channel|>commentary<|message|>{content}<|end|>" + ) + elif role == "assistant_final": + prompt_text += ( + f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>" + ) + else: + raise ValueError(f"Unknown role: {role}") + return prompt_text + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + ) -> List[torch.Tensor]: + # conversation = process_harmony_conversations(conversation) + if not preformatted: + prompt_text = "" + for j, message in enumerate(conversation): + if j == 0 and ( + message["role"] != "system" + or message["role"] != "assistant_reasoning_effort" + ): + prompt_text = self.build_single_turn_prompt( + prompt_text, + "assistant_reasoning_effort", + self.default_reasoning_level, + ) + prompt_text = self.build_single_turn_prompt( + prompt_text, message["role"], message["content"] + ) + conversation = prompt_text + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.unk_token_id + + encoding = self.tokenizer( + conversation, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + offsets = encoding.offset_mapping[0] + loss_mask = torch.zeros(len(input_ids), dtype=torch.long) + + # Find spans of assistant responses using regex + # We match `<|start|>assistant` and only extract the content following it. + # This continues until `<|start|>user<|message|>` appears, or until the end of the string. + pattern = re.compile( + r"<\|start\|>assistant([\s\S]*?)(?=<\|start\|>user<\|message\|>|$)" + ) + + # Find all matching segments + matches = list(pattern.finditer(conversation)) + if train_only_last_turn and matches: + matches = [matches[-1]] # Only keep the last match + + for match in matches: + # match.start(0) is the start index of the full match (including `<|start|>assistant`) + # match.start(1) is the start index of the first capture group (excluding `<|start|>assistant`) + # match.end(1) is the end index of the content + start_char = match.start(1) + end_char = match.end(1) + + # Map character indices to token indices + for idx, (ts, te) in enumerate(offsets): + # Set mask to 1 only if the token's character range falls entirely within the "content area" + if ts >= start_char and te <= end_char: + loss_mask[idx] = 1 + + return input_ids, loss_mask + + +class ThinkingParser(GeneralParser): + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + + def apply_chat_template(self, messages, **kwargs) -> str: + if messages[-1]["role"] == "assistant": + conversation_history = self.tokenizer.apply_chat_template( + messages[:-1], + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + **kwargs, + ) + conversation = ( + conversation_history + + messages[-1]["content"] + + self.chat_template.end_of_turn_token + ) + return conversation + else: + raise Exception( + f"The last message is not assistant but {messages[-1]['role']}" + ) + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, + ) -> Dict[str, List[torch.Tensor]]: + if self.chat_template.enable_thinking: + kwargs["enable_thinking"] = True + else: + pass + return super().parse( + conversation, max_length, preformatted, train_only_last_turn, **kwargs + ) diff --git a/progress/SpecForge/specforge/data/preprocessing.py b/progress/SpecForge/specforge/data/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..46f3d8647600ee7678f2b6cd37026cebe33c1c1e --- /dev/null +++ b/progress/SpecForge/specforge/data/preprocessing.py @@ -0,0 +1,752 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import io +import os +import re +import warnings +from collections import Counter +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from tqdm import tqdm +from transformers import ImageProcessingMixin, PreTrainedTokenizer + +from datasets import Dataset as HFDataset + +from ..distributed import get_draft_sp_group, get_sp_ring_group + +try: + from qwen_vl_utils import process_vision_info + + HAS_QWEN_VL_UTILS = True +except ImportError: + HAS_QWEN_VL_UTILS = False + process_vision_info = None + + +from .parse import GeneralParser, HarmonyParser, ThinkingParser +from .template import TEMPLATE_REGISTRY, ChatTemplate + +# define a type called conversation +Conversation = List[Dict[str, str]] + + +# ============================== +# This file is for preprocessing the data +# ============================== + + +def _apply_loss_mask_from_chat_template( + text: str, + offsets: torch.Tensor, + chat_template: ChatTemplate, +) -> torch.Tensor: + """ + Apply loss mask to identify assistant response spans using chat template. + + Args: + text: The formatted conversation text. + offsets: Token offset mapping from tokenizer. + chat_template: The chat template to use for identifying assistant spans. + + Returns: + A tensor indicating which tokens should contribute to the loss (1) or not (0). + """ + loss_mask = torch.zeros(len(offsets), dtype=torch.long) + + user_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.user_header}" + ) + assistant_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" + ) + + # Find spans of assistant responses using regex + assistant_pattern = ( + re.escape(assistant_message_separator) + + r"(.*?)(?=" + + re.escape(user_message_separator) + + "|$)" + ) + + matches_found = 0 + + for match in re.finditer(assistant_pattern, text, re.DOTALL): + matches_found += 1 + # Assistant response text span (excluding assistant_header itself) + assistant_start_char = match.start(1) + assistant_end_char = match.end(1) + + # Mark tokens overlapping with assistant response + for idx, (token_start, token_end) in enumerate(offsets): + # Token is part of the assistant response span + if token_end <= assistant_start_char: + continue # token before assistant text + if token_start > assistant_end_char: + continue # token after assistant text + loss_mask[idx] = 1 + + if matches_found == 0: + print("WARNING: No assistant response spans found in the conversation text.") + + return loss_mask + + +# Copied from https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py +def preprocess_conversations( + tokenizer: PreTrainedTokenizer, + conversations: Union[List[Conversation], List[str]], + chat_template: ChatTemplate, + max_length: int = 2048, + is_preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, +) -> Dict[str, List[torch.Tensor]]: + """ + Preprocess a batch of ShareGPT style conversations or pre-formatted text. + + Args: + tokenizer: The tokenizer to use for tokenization. + conversations: A list of conversations (if is_preformatted=False) or + a list of pre-formatted text strings (if is_preformatted=True). + chat_template: The chat template to use for formatting/identifying spans. + max_length: The maximum length of the tokenized input. + is_preformatted: Whether the input is already formatted text strings. + train_only_last_turn: If True, only the last assistant turn contributes to the loss. + + Returns: + A dictionary containing: + - input_ids: List of tokenized input IDs. + - loss_mask: List of loss masks indicating which tokens should contribute to the loss. + - attention_mask: List of attention masks. + """ + + # prepare result + results = {"input_ids": [], "loss_mask": [], "attention_mask": []} + + if chat_template.parser_type == "general": + parser = GeneralParser(tokenizer, chat_template) + elif chat_template.parser_type == "thinking": + parser = ThinkingParser(tokenizer, chat_template) + elif chat_template.parser_type == "openai-harmony": + parser = HarmonyParser(tokenizer, chat_template) + else: + raise ValueError(f"Invalid parser type: {chat_template.parser_type}") + + kwargs_list = [{} for _ in range(len(conversations))] + for key, value_list in kwargs.items(): + for i, value in enumerate(value_list): + kwargs_list[i][key] = value + for source, kwargs_item in zip(conversations, kwargs_list): + if not source: + # if the source is None, skip it + continue + input_ids, loss_mask = parser.parse( + source, + max_length, + preformatted=is_preformatted, + train_only_last_turn=train_only_last_turn, + **kwargs_item, + ) + results["input_ids"].append(input_ids[None, :]) + results["loss_mask"].append(loss_mask[None, :]) + results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) + return results + + +def preprocess_vlm_conversations( + processor: ImageProcessingMixin, + examples: List[Conversation], + chat_template: ChatTemplate, + max_length: int = 2048, +) -> Dict[str, List[torch.Tensor]]: + """ + Preprocess a batch of ShareGPT style conversations. + + Args: + processor: The image processor to use for processing images. + examples: A list of examples, where each example is a dictionary containing: + - image: The image in the conversation. + - conversations: A list of conversations, where each conversation is a list of messages. + chat_template: The chat template to use for formatting the conversations. + max_length: The maximum length of the tokenized input. + + Returns: + A dictionary containing: + - input_ids: List of tokenized input IDs. + - loss_mask: List of loss masks indicating which tokens should contribute to the loss. + - attention_mask: List of attention masks. + - pixel_values: List of pixel values for images in the examples. + - image_grid_thw: List of image grid tensors. + """ + system_prompt = chat_template.system_prompt + + # prepare result + results = { + "input_ids": [], + "loss_mask": [], + "attention_mask": [], + "pixel_values": [], + "image_grid_thw": [], + } + + # Note: currently, we assume that each example has only one image + for i, image in enumerate(examples["image"]): + source = examples["conversations"][i] + messages = [{"role": "system", "content": system_prompt}] + if not source: + # if the source is None, skip it + continue + + if source[0]["role"] != "user": + # if the first message is not from user, skip it + source = source[1:] + + convroles = ["user", "assistant"] + for j, sentence in enumerate(source): + role = sentence["role"] + assert role == convroles[j % 2], f"unexpected role {role}" + if role == "user": + # if the message is from user and has image, process the image + messages.append( + { + "role": role, + "content": [ + { + "type": "image", + "image": image, + }, + {"type": "text", "text": sentence["content"]}, + ], + } + ) + else: + messages.append({"role": role, "content": sentence["content"]}) + + conversation = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + # get vision infor use qwen_vl_utils + if not HAS_QWEN_VL_UTILS: + raise ImportError( + "qwen_vl_utils is required for VLM preprocessing but is not installed. " + "Please install it to use VLM features." + ) + image_inputs, video_inputs = process_vision_info(messages) + assert image_inputs is not None, "image_inputs must not be None" + + encoding = processor( + text=[conversation], + images=image_inputs, + videos=video_inputs, + max_length=max_length, + truncation=True, + return_tensors="pt", + return_offsets_mapping=True, + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + offsets = encoding.offset_mapping[0] + pixel_values = encoding.pixel_values + image_grid_thw = encoding.image_grid_thw[0] + + # get conversation with image info for loss mask generation + decoded_conversation = processor.tokenizer.decode( + encoding.input_ids[0], skip_special_tokens=False + ) + + # Apply loss mask + loss_mask = _apply_loss_mask_from_chat_template( + decoded_conversation, offsets, chat_template + ) + + results["input_ids"].append(input_ids[None, :]) + results["loss_mask"].append(loss_mask[None, :]) + results["attention_mask"].append(torch.ones_like(loss_mask)[None, :]) + results["pixel_values"].append(pixel_values) + results["image_grid_thw"].append(image_grid_thw[None, :]) + return results + + +def build_eagle3_dataset( + dataset: HFDataset, + tokenizer: PreTrainedTokenizer, + chat_template: Optional[str] = None, + max_length: Optional[int] = 2048, + shuffle_seed: Optional[int] = 42, + num_proc: Optional[int] = 8, + cache_dir: Optional[str] = None, + cache_key: Optional[str] = None, + is_vlm: Optional[bool] = False, + processor: Optional[ImageProcessingMixin] = None, + is_preformatted: Optional[bool] = False, + train_only_last_turn: Optional[bool] = False, +) -> HFDataset: + """ + build eagle3 dataset + + Args: + dataset: HF dataset to process. + tokenizer: The tokenizer to use for tokenization. + chat_template: The chat template to use for formatting conversations. + This includes the system prompt and user/assistant tokens + required to delineate different parts of the conversation + for loss mask generation. + max_length: The maximum length of the tokenized input. + shuffle_seed: The seed for shuffling the dataset. + num_proc: The number of processes to use for multiprocessing. + cache_dir: The directory to use for caching the processed dataset. + cache_key: The key to use for caching the processed dataset. + is_vlm: Whether the dataset is for VLM models. + processor: The image processor to use for processing images. + is_preformatted: Whether the dataset contains preformatted text of the conversation + (e.g. includes system prompt, user and assistant start and end tokens) + and doesn't need to have the chat template applied. + Note that the chat_template still needs to be specified to determine + the assistant spans for loss mask generation. + If True, expects "text" column with ready-to-train text. + If False, expects "conversations" column with ShareGPT format. + train_only_last_turn: If True, only the last assistant turn contributes to the loss. + Useful for thinking models where history may not contain thoughts. + + Returns: + The processed HF dataset. + """ + if is_vlm: + assert processor is not None, "processor must be provided when is_vlm is True" + + # Validate chat_template requirement + if chat_template is None: + raise ValueError("chat_template must be provided for all dataset types") + + assert ( + chat_template in TEMPLATE_REGISTRY.get_all_template_names() + ), f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" + + template: ChatTemplate = TEMPLATE_REGISTRY.get(chat_template) + + dataset = dataset.shuffle(seed=shuffle_seed) + original_cols = dataset.column_names + + def preprocess_function(examples): + # Handle different dataset formats + if is_vlm: + processed = preprocess_vlm_conversations( + processor, + examples, + template, + max_length, + ) + elif is_preformatted: + # Handle pre-formatted text (should be in "text" column) + if "text" not in examples: + raise ValueError( + f"Expected 'text' column for is_preformatted=True, but found columns: {list(examples.keys())}" + ) + processed = preprocess_conversations( + tokenizer, + examples["text"], + template, + max_length, + is_preformatted=True, + train_only_last_turn=train_only_last_turn, + ) + else: + # Handle ShareGPT conversations + if "conversations" not in examples: + raise ValueError( + f"Expected 'conversations' column for is_preformatted=False, but found columns: {list(examples.keys())}" + ) + conversations = examples.pop("conversations") + if "id" in examples: + examples.pop("id") + processed = preprocess_conversations( + tokenizer, + conversations, + template, + max_length, + is_preformatted=False, + train_only_last_turn=train_only_last_turn, + **examples, + ) + + return processed + + # Process dataset only once + if cache_dir and cache_key: + load_from_cache_file = True + os.makedirs(cache_dir, exist_ok=True) + cache_file_name = os.path.join(cache_dir, f"{cache_key}.pkl") + print(f"dataset is cached at {cache_file_name}") + elif cache_dir is None and cache_key is None: + load_from_cache_file = False + cache_file_name = None + print(f"dataset is not cached") + else: + warnings.warn( + f"cache_dir and cache_key must be provided together to make caching work" + ) + + # adjust batch size based on dataset type + if is_vlm: + batch_size = ( + 200 # reduce batch size for VLM datasets to avoid PyArrow offset overflow + ) + else: + batch_size = 1000 # default for conversations + dataset = dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + batch_size=batch_size, + remove_columns=original_cols, + # keep_in_memory=True, + load_from_cache_file=load_from_cache_file, + cache_file_name=cache_file_name, + ) + + dataset.set_format(type="torch") + return dataset + + +# ============================== +# Offline Eagle3 Dataset +# ============================== +# modified from https://github.com/NickL77/BaldEagle/blob/master/train/modules/data/data.py +def list_local_files(path, suffixes=None): + if suffixes is None: + suffixes = [".ckpt", ".ckpt.gz"] + datapaths = [] + for root, directories, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + datapaths.append(file_path) + if suffixes: + datapaths = [ + f_name + for f_name in datapaths + if any(f_name.endswith(suffix) for suffix in suffixes) + ] + return datapaths + + +class OfflineEagle3Dataset(torch.utils.data.Dataset): + def __init__( + self, + datapath, + transform=None, + max_len=2048, + ttt_length=1, + use_usp_preprocess=False, + ): + """ + Args: + datapath: List of file paths. + transform: Optional transform to apply. + max_len: Maximum sequence length to load. + ttt_length: TTT overlap length used in USP preprocessing. + use_usp_preprocess: Whether to shard all sequences with USP overlap in preprocessing. + """ + self.datapaths = datapath + self.transform = transform + self._epoch = 0 + self.max_len = max_len + self.ttt_length = ttt_length + self.use_usp_preprocess = use_usp_preprocess + if use_usp_preprocess: + sp_group = get_draft_sp_group() + self.sp_rank = torch.distributed.get_rank(sp_group) + self.sp_size = torch.distributed.get_world_size(sp_group) + ring_group = get_sp_ring_group() + self.ring_rank = torch.distributed.get_rank(ring_group) + self.sp_ring_size = torch.distributed.get_world_size(ring_group) + + @staticmethod + def process_data(data, max_len, transform=None): + new_data = {} + # Squeeze due to our data generation script adding a batch dimension + hidden_state = data["aux_hidden_state"].squeeze(0)[:max_len][None, :] + target = data["hidden_state"].squeeze(0)[:max_len][None, :] + + input_ids = data["input_ids"][:max_len][None, :] + loss_mask = data["loss_mask"][:max_len][None, :] + loss_mask[0, -1] = 0 + + new_data["attention_mask"] = torch.ones_like(loss_mask, dtype=torch.long) + new_data["loss_mask"] = loss_mask + new_data["target"] = target + new_data["hidden_state"] = hidden_state + new_data["input_ids"] = input_ids + if transform: + new_data = transform(new_data) + return new_data + + @staticmethod + def process_data_usp( + data, + max_len, + ttt_length=1, + transform=None, + sp_rank=0, + sp_size=1, + ring_rank=0, + sp_ring_size=1, + ): + """ + USP preprocess: shard all sequences by sp_rank and add TTT overlap. + Each local sequence length = ceil(max_len / sp_size) + ttt_length. + """ + new_data = {} + + input_ids = data["input_ids"] + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + global_len = min(max_len, input_ids.shape[1]) + chunk_size = (global_len + sp_size - 1) // sp_size + start = sp_rank * chunk_size + local_len = chunk_size + ttt_length + + end = min(start + local_len, global_len) + + def _slice_and_pad(tensor): + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + tensor = tensor[:, :global_len] + sliced = tensor[:, start : min(end, tensor.shape[1])] + valid_len = sliced.shape[1] + if valid_len < local_len: + pad_len = local_len - valid_len + if tensor.ndim == 2: + sliced = F.pad(sliced, (0, pad_len)) + else: + sliced = F.pad(sliced, (0, 0, 0, pad_len)) + return sliced.contiguous(), valid_len + + if "aux_hidden_state" not in data or data["aux_hidden_state"] is None: + raise KeyError("aux_hidden_state is required for OfflineEagle3Dataset") + new_data["hidden_state"], _ = _slice_and_pad(data["aux_hidden_state"]) + new_data["target"], _ = _slice_and_pad(data["hidden_state"]) + + new_data["input_ids"], valid_len = _slice_and_pad(input_ids) + + full_loss_mask = data["loss_mask"] + if full_loss_mask.ndim == 1: + full_loss_mask = full_loss_mask.unsqueeze(0) + + full_loss_mask = full_loss_mask[:, :global_len].clone() + if full_loss_mask.numel() > 0: + full_loss_mask[0, -1] = 0 + new_data["loss_mask"], _ = _slice_and_pad(full_loss_mask) + + local_len = new_data["input_ids"].shape[1] + attention_mask = torch.zeros((1, local_len), dtype=torch.long) + attention_mask[:, :valid_len] = 1 + new_data["attention_mask"] = attention_mask + + # Position ids should align with Ulysses all2all-expanded sequence length. + # Local seq_len (per sp_rank) = local_len; attention uses (local_len - ttt_length). + sp_ulysses_size = max(1, sp_size // sp_ring_size) + usp_chunk_size = max(local_len - ttt_length, 0) + ring_chunk = usp_chunk_size * sp_ulysses_size + ring_start = ring_rank * ring_chunk + new_data["position_ids"] = torch.arange( + ring_start, ring_start + ring_chunk, dtype=torch.long + ).unsqueeze(0) + + if transform: + new_data = transform(new_data) + + return new_data + + def __len__(self): + return len(self.datapaths) + + def _open_file(self, index): + """ + Opens the file with memory mapping. + This operation is virtually instant and consumes negligible RAM + because no data is actually read from disk yet. + """ + data_path = self.datapaths[index] + if data_path.endswith(".gz"): + with gzip.open(data_path, "rb") as f: + return torch.load(io.BytesIO(f.read()), weights_only=False) + return torch.load(data_path, weights_only=False, mmap=True) + + def __getitem__(self, index): + try: + data = self._open_file(index) + except Exception as e: + print(f"ERROR Failed to load {self.datapaths[index]} with error {e}") + data = self._open_file(0) + + # 2. Read only specific bytes from disk + if self.use_usp_preprocess: + return self.process_data_usp( + data, + self.max_len, + ttt_length=self.ttt_length, + transform=self.transform, + sp_rank=self.sp_rank, + sp_size=self.sp_size, + ring_rank=self.ring_rank, + sp_ring_size=self.sp_ring_size, + ) + return self.process_data( + data, + self.max_len, + self.transform, + ) + + def set_epoch(self, epoch): + self._epoch = epoch + + +def build_offline_eagle3_dataset( + hidden_states_path: str, + max_len: int = 2048, + ttt_length: int = 1, + use_usp_preprocess: bool = False, +) -> torch.utils.data.Dataset: + + return OfflineEagle3Dataset( + list_local_files(hidden_states_path), + max_len=max_len, + ttt_length=ttt_length, + use_usp_preprocess=use_usp_preprocess, + ) + + +# ============================== +# Vocab Mapping +# ============================== +def generate_vocab_mapping_file( + dataset: HFDataset, + target_vocab_size: int, + draft_vocab_size: int, + cache_dir: str = "./cache/vocab_mapping", + cache_key: str = "vocab_mapping", +) -> str: + """ + Generate a vocab mapping file for the dataset. + + Args: + dataset: The dataset to process. + target_vocab_size: The target vocabulary size. + draft_vocab_size: The draft vocabulary size. + cache_dir: The directory to use for caching the vocab mapping file. + cache_key: The key to use for caching the vocab mapping file. + + Returns: + The path to the vocab mapping file. + """ + # prepare cache directory + os.makedirs(cache_dir, exist_ok=True) + vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt") + + if os.path.exists(vocab_mapping_path): + print(f"Loading vocab mapping from the cached file at: {vocab_mapping_path}") + return vocab_mapping_path + + # we first count the frequency of effective tokens in the dataset + token_dict = Counter() + for input_ids, loss_mask in tqdm( + zip(dataset["input_ids"], dataset["loss_mask"]), + total=len(dataset), + desc="Counting tokens for vocab mapping", + ): + masked_ids = input_ids[loss_mask == 1] + unique_ids, counts = masked_ids.unique(return_counts=True) + batch_token_dict = dict(zip(unique_ids.tolist(), counts.tolist())) + token_dict.update(batch_token_dict) + + # generate the d2t and t2d mapping + d2t, t2d = process_token_dict_to_mappings( + token_dict, + draft_vocab_size, + target_vocab_size, + ) + + vocab_mapping = { + "d2t": d2t, + "t2d": t2d, + } + torch.save(vocab_mapping, vocab_mapping_path) + print(f"Saved vocab mapping to: {vocab_mapping_path}") + return vocab_mapping_path + + +def process_token_dict_to_mappings( + token_dict: Counter, + draft_vocab_size: int, + target_vocab_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Process token_dict to create d2t and t2d mappings, with optional caching. + + Args: + token_dict: A Counter object mapping token ids to their frequencies. + draft_vocab_size: The size of the draft vocabulary. + target_vocab_size: The size of the target vocabulary. + + Returns: + A tuple containing: + - d2t: A tensor mapping draft token ids to target token ids. + - t2d: A tensor mapping target token ids to draft token ids. + """ + if len(token_dict) < draft_vocab_size: + existing_tokens = set(token_dict.keys()) + missing_tokens = set(range(draft_vocab_size)) - existing_tokens + for token in missing_tokens: + token_dict[token] = 0 + if len(token_dict) >= draft_vocab_size: + break + print(f"Added missing tokens to reach draft vocab size: {draft_vocab_size}") + print(f"Total tokens after addition: {len(token_dict)}") + total_frequency = sum(token_dict.values()) + top_N = token_dict.most_common(draft_vocab_size) + top_N_frequency_sum = sum(freq for key, freq in top_N) + + if total_frequency == 0: + print( + "Warning: Total token frequency is zero. All tokens will have zero ratio." + ) + top_N_ratio = 0.0 + else: + top_N_ratio = top_N_frequency_sum / total_frequency + + print(f"top {draft_vocab_size} token frequency ratio: {top_N_ratio:.2%}") + used_tokens = [key for key, freq in top_N] + used_tokens.sort() + + d2t = [used_tokens[i] - i for i in range(len(used_tokens))] + t2d = [i in used_tokens for i in range(target_vocab_size)] + d2t = torch.tensor(d2t) + t2d = torch.tensor(t2d) + + return d2t, t2d diff --git a/progress/SpecForge/specforge/data/template.py b/progress/SpecForge/specforge/data/template.py new file mode 100644 index 0000000000000000000000000000000000000000..4803db9af06fb4feb4d6cb5b577ca5766be82a4b --- /dev/null +++ b/progress/SpecForge/specforge/data/template.py @@ -0,0 +1,310 @@ +# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/chat_template.py#L13 +from typing import List + +from pydantic import BaseModel + + +class ChatTemplate(BaseModel): + """ + This is a dataclass for the chat template. + + Args: + assistant_header(str): The header for the assistant. + user_header(str): The header for the user. + system_prompt(str): The system prompt. + end_of_turn_token(str): The end token of a turn of conversation. + """ + + assistant_header: str | None + user_header: str | None + system_prompt: str | None + end_of_turn_token: str | None + parser_type: str = "general" + assistant_pattern_type: str = "general" + enable_thinking: bool = False + + +class TemplateRegistry: + """ + This is a registry for the chat template. Sgl-spec will register some common chat templates here. + If you have a custom chat template, you can register it via the example below. + + Example: + ```python + from specforge.data.template import TEMPLATE_REGISTRY, ChatTemplate + TEMPLATE_REGISTRY.register( + name="custom", + template=ChatTemplate( + assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n", + user_header="<|start_header_id|>user<|end_header_id|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|eot_id|>" + ) + ) + ``` + """ + + def __init__(self): + self.templates = {} + + def register(self, name: str, template: ChatTemplate, override: bool = False): + """ + Register a chat template for a model type. + + Args: + name(str): The name of the chat template. + template(ChatTemplate): The chat template. + override(bool): Whether to override the existing template, default to False + """ + assert ( + not override and name not in self.templates + ), f"Chat template for the model type {name} has already been registered" + self.templates[name] = template + + def get(self, name: str) -> ChatTemplate: + """ + Get the chat template for a model type. + + Args: + name(str): The name of the chat template. + + Returns: + ChatTemplate: The chat template. + """ + return self.templates[name] + + def get_all_template_names(self) -> List[str]: + """ + Get all the template names. + + Returns: + List[str]: The list of template names. + """ + return list(self.templates.keys()) + + +# global registry +TEMPLATE_REGISTRY = TemplateRegistry() + +# Register the common template here +TEMPLATE_REGISTRY.register( + name="llama3", + template=ChatTemplate( + assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n", + user_header="<|start_header_id|>user<|end_header_id|>", + system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.", + end_of_turn_token="<|eot_id|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="llama4", + template=ChatTemplate( + assistant_header="<|header_start|>assistant<|header_end|>\n\n", + user_header="<|header_start|>user<|header_end|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|eot|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen2-vl", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi3", + template=ChatTemplate( + assistant_header="<|assistant|>\n", + user_header="<|user|>\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi4", + template=ChatTemplate( + assistant_header="<|im_start|>assistant<|im_sep|>", + user_header="<|im_start|>user<|im_sep|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="phi4-mini", + template=ChatTemplate( + assistant_header="<|assistant|>", + user_header="<|user|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="gpt-oss-naive", + template=ChatTemplate( + assistant_header="<|start|>assistant<|channel|>analysis<|message|>", + user_header="<|start|>user<|message|>", + system_prompt=None, + end_of_turn_token="<|end|>", + ), +) + + +TEMPLATE_REGISTRY.register( + name="gpt-oss", + template=ChatTemplate( + assistant_header=None, # the headers are not applicable to openai-harmony's channel tags + user_header=None, + system_prompt=None, + end_of_turn_token=None, + parser_type="openai-harmony", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-r1-distill", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + end_of_turn_token=None, + system_prompt=None, + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen3-thinking", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + parser_type="thinking", + enable_thinking=True, + ), +) + + +TEMPLATE_REGISTRY.register( + name="qwen3-instruct", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n\n\n\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="qwen3-next-thinking", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="kimi-k2-thinking", + template=ChatTemplate( + assistant_header="<|im_assistant|>assistant<|im_middle|>", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="kimi-k2-instruct", + template=ChatTemplate( + assistant_header="<|im_assistant|>assistant<|im_middle|>", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-v3", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|end▁of▁sentence|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="ling-flash-2.0", + template=ChatTemplate( + assistant_header="ASSISTANT", + user_header="HUMAN", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|role_end|>", + ), +) + +TEMPLATE_REGISTRY.register( + name="deepseek-v32", + template=ChatTemplate( + assistant_header="<|Assistant|>", + user_header="<|User|>", + system_prompt="", + end_of_turn_token="<|end▁of▁sentence|>", + parser_type="thinking", + enable_thinking=True, + ), +) + +TEMPLATE_REGISTRY.register( + name="gemma", + template=ChatTemplate( + assistant_header="model\n", + user_header="user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="\n", + ), +) + +TEMPLATE_REGISTRY.register( + name="longcat", + template=ChatTemplate( + assistant_header=" ASSISTANT:", + user_header=" USER:", + system_prompt="You are a helpful assistant.", + end_of_turn_token="", + assistant_pattern_type="longcat", + ), +) + +TEMPLATE_REGISTRY.register( + name="longcat_xml", + template=ChatTemplate( + assistant_header="", + user_header="", + system_prompt="You are a helpful assistant.", + end_of_turn_token="", + ), +) diff --git a/progress/SpecForge/specforge/data/utils.py b/progress/SpecForge/specforge/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..93fd6f58a6adf06051a14a58e5999bc35dad4dd4 --- /dev/null +++ b/progress/SpecForge/specforge/data/utils.py @@ -0,0 +1,347 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler + +from datasets import Dataset +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group + + +class DataCollatorWithPadding: + """ + Datacollator that will dynamically pad the inputs for batching. + """ + + def __init__(self): + self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group()) + self.ulysses_degree = torch.distributed.get_world_size(get_sp_ulysses_group()) + + def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad to the longest sequence in the batch. + + Args: + intensors: (B, n, S) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N, S) + """ + B, n, S = intensors.shape + padding_tensor = torch.zeros( + B, N - n, S, dtype=intensors.dtype, device=intensors.device + ) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad 2D tensor to the longest sequence in the batch. + + Args: + intensors: (B, n) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N) + """ + B, n = intensors.shape + padding_tensor = torch.zeros( + B, N - n, dtype=intensors.dtype, device=intensors.device + ) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate a batch of features. + + Args: + features: A list of features, where each feature is a dictionary containing: + - input_ids: torch.Tensor of shape (n,) + - attention_mask: torch.Tensor of shape (n,) + - loss_mask: torch.Tensor of shape (n,) + + Returns: + A dictionary containing: + - input_ids: torch.Tensor of shape (B, N) + - attention_mask: torch.Tensor of shape (B, N) + - loss_mask: torch.Tensor of shape (B, N) + """ + max_length = max(item["input_ids"].shape[1] for item in features) + + # pad for sequence parrel + max_length = ( + (max_length + self.sp_degree - 1) // self.sp_degree + ) * self.sp_degree + # position max len, ulysses do not need chuck position ids + position_max_len = max_length * self.ulysses_degree + + batch_input_ids = torch.cat( + [self.paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [ + self.paddingtensor2D(item["attention_mask"], max_length) + for item in features + ] + ) + batch_loss_mask = torch.cat( + [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + if "position_ids" in features[0]: + batch_position_ids = torch.cat( + [ + self.paddingtensor2D(item["position_ids"], position_max_len) + for item in features + ] + ) + else: + batch_position_ids = None + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "hidden_state": None, + "target": None, + } + if batch_position_ids is not None: + batch["position_ids"] = batch_position_ids + if all("hidden_state" in item for item in features): + assert all( + "target" in item for item in features + ), "target is required when hidden_state is provided" + if self.sp_degree > 1: # USP mode + batch["hidden_state"] = torch.cat( + [item["hidden_state"] for item in features] + ) + else: + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) + batch["target"] = torch.cat( + [self.paddingtensor(item["target"], max_length) for item in features] + ) + return batch + + +class VlmDataCollatorWithPadding: + """ + Datacollator that will dynamically pad the inputs for batching. + """ + + def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad to the longest sequence in the batch. + + Args: + intensors: (B, n, S) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N, S) + """ + B, n, S = intensors.shape + padding_tensor = torch.zeros(B, N - n, S, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: + """ + Pad 2D tensor to the longest sequence in the batch. + + Args: + intensors: (B, n) + N: the length to pad to, N >= n + + Returns: + outtensors: (B, N) + """ + B, n = intensors.shape + padding_tensor = torch.zeros(B, N - n, dtype=intensors.dtype) + outtensors = torch.cat((intensors, padding_tensor), dim=1) + return outtensors + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate a batch of features. + + Args: + features: A list of features, where each feature is a dictionary containing: + - input_ids: torch.Tensor of shape (n,) + - attention_mask: torch.Tensor of shape (n,) + - loss_mask: torch.Tensor of shape (n,) + - pixel_values: torch.Tensor of shape (grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size) + - image_grid_thw: torch.Tensor of shape (3,) + + Returns: + A dictionary containing: + - input_ids: torch.Tensor of shape (B, N) + - attention_mask: torch.Tensor of shape (B, N) + - loss_mask: torch.Tensor of shape (B, N) + """ + max_length = max(item["input_ids"].shape[1] for item in features) + batch_input_ids = torch.cat( + [self.paddingtensor2D(item["input_ids"], max_length) for item in features] + ) + batch_attention_mask = torch.cat( + [ + self.paddingtensor2D(item["attention_mask"], max_length) + for item in features + ] + ) + batch_loss_mask = torch.cat( + [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] + ) + batch_pixel_values = torch.cat( + [item["pixel_values"] for item in features], dim=0 + ) + batch_image_grid_thw = torch.cat( + [item["image_grid_thw"] for item in features], dim=0 + ) + batch = { + "input_ids": batch_input_ids, + "attention_mask": batch_attention_mask, + "loss_mask": batch_loss_mask, + "pixel_values": batch_pixel_values, + "image_grid_thw": batch_image_grid_thw, + "hidden_state": None, + "target": None, + } + if all("hidden_state" in item for item in features): + assert all( + "target" in item for item in features + ), "target is required when hidden_state is provided" + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) + batch["target"] = torch.cat( + [self.paddingtensor(item["target"], max_length) for item in features] + ) + return batch + + +def prepare_dp_dataloaders( + dataset: Dataset, + batch_size: int, + num_workers: int = 4, + process_group: Optional[dist.ProcessGroup] = None, + pin_memory: Optional[bool] = False, + shuffle: Optional[bool] = False, + is_vlm: Optional[bool] = False, + prefetch_factor: Optional[int] = 2, + **dataloader_kwargs, +) -> DataLoader: + """ + Prepare dataloader for distributed data parallel training. + + Args: + dataset: The dataset to load data from. + batch_size: The batch size for each GPU. + num_workers: The number of workers for data loading. + process_group: The process group for distributed training. + pin_memory: Whether to pin memory for data loading. + shuffle: Whether to shuffle the dataset. + is_vlm: Whether the dataset is a vision-language model dataset. + **dataloader_kwargs: Additional keyword arguments for the DataLoader. + + Returns: + A DataLoader for the dataset. + """ + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle + ) + if is_vlm: + datacollator_cls = VlmDataCollatorWithPadding + else: + datacollator_cls = DataCollatorWithPadding + + if num_workers == 0: + prefetch_factor = None + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + collate_fn=datacollator_cls(), + drop_last=True, + **dataloader_kwargs, + ) + return dataloader + + +def parse_harmony_message_content(content): + """ + 解析 content 字符串中的 Harmony 格式。 + 如果匹配到 Harmony 格式,返回包含 channel 和 content 的列表; + 否则,返回原内容并标记为默认 channel。 + """ + # 匹配 <|channel|>xxx<|message|>yyy<|end|> + pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end|>" + matches = re.findall(pattern, content, re.DOTALL) + + if not matches: + # 如果没有匹配到 Harmony 标签,视作普通文本 + return [{"channel": "text", "content": content}] + + results = [] + for channel, msg_body in matches: + results.append({"channel": channel.strip(), "content": msg_body.strip()}) + return results + + +def process_harmony_conversations(conversation): + """ + 处理传入的 list[list[dict]] 结构 + """ + new_conversation = [] + for msg in conversation: + role = msg.get("role") + original_content = msg.get("content", "") + + # 解析 content 中的 Harmony 结构 + segments = parse_harmony_message_content(original_content) + + # 为每个解析出的通道生成一个新的消息字典 + for seg in segments: + new_msg = { + "role": role, + "channel": seg["channel"], # 新增字段标识通道 + "content": seg["content"], + } + new_conversation.append(new_msg) + + return new_conversation diff --git a/progress/SpecForge/specforge/distributed.py b/progress/SpecForge/specforge/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5e882c4d69bc2cf8e03afe4fc05f3d60bdc3c6 --- /dev/null +++ b/progress/SpecForge/specforge/distributed.py @@ -0,0 +1,245 @@ +from datetime import timedelta +from typing import Any, Optional + +import torch +import torch.distributed as dist +from yunchang.globals import PROCESS_GROUP, set_seq_parallel_pg + +from specforge.utils import print_with_rank + +_DEVICE_MESH = None +_TP_DEVICE_MESH = None +_TP_GROUP = None +_DP_DEVICE_MESH = None +_DP_GROUP = None +_DRAFT_DP_GROUP = None +_DRAFT_SP_GROUP = None +_SP_ULYSSES_GROUP = None +_SP_RING_GROUP = None + + +def get_tp_group(): + global _TP_GROUP + return _TP_GROUP + + +def get_dp_group(): + global _DP_GROUP + return _DP_GROUP + + +def get_draft_dp_group(): + global _DRAFT_DP_GROUP + return _DRAFT_DP_GROUP + + +def get_draft_sp_group(): + global _DRAFT_SP_GROUP + return _DRAFT_SP_GROUP + + +def get_device_mesh(): + global _DEVICE_MESH + return _DEVICE_MESH + + +def get_tp_device_mesh(): + global _TP_DEVICE_MESH + return _TP_DEVICE_MESH + + +def get_dp_device_mesh(): + global _DP_DEVICE_MESH + return _DP_DEVICE_MESH + + +def get_sp_ulysses_group(): + global _SP_ULYSSES_GROUP + return _SP_ULYSSES_GROUP + + +def get_sp_ring_group(): + global _SP_RING_GROUP + return _SP_RING_GROUP + + +def init_distributed( + timeout: int = 10, tp_size: int = 1, sp_ulysses_size: int = 1, sp_ring_size: int = 1 +): + """Initialize distributed training. + + Args: + timeout(int): Timeout for collective communication in minutes + tp_size(int): The degree of tensor parallelism + """ + dist.init_process_group(backend="nccl", timeout=timedelta(minutes=timeout)) + local_rank = dist.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(local_rank) + print_with_rank(f"bind to device {local_rank}") + + world_size = dist.get_world_size() + dp_size = world_size // tp_size + assert ( + world_size == tp_size * dp_size + ), f"world size must be divisible by tp size, now {world_size=}, {(tp_size * dp_size)=} " + + device_mesh = dist.device_mesh.init_device_mesh( + "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + + assert ( + world_size % (sp_ulysses_size * sp_ring_size) == 0 + ), f"World size ({world_size}) cannot be evenly divided by total SP size ({sp_ulysses_size*sp_ring_size})" + + draft_dp_size = world_size // (sp_ulysses_size * sp_ring_size) + draft_device_mesh = dist.device_mesh.init_device_mesh( + "cuda", + (draft_dp_size, sp_ulysses_size * sp_ring_size), + mesh_dim_names=("draft_dp", "sp"), + ) + set_seq_parallel_pg(sp_ulysses_size, sp_ring_size, dist.get_rank(), world_size) + + print_with_rank(f"device mesh: {device_mesh}") + tp_group = device_mesh.get_group("tp") + dp_group = device_mesh.get_group("dp") + + sp_ulysses_group = PROCESS_GROUP.ULYSSES_PG + sp_ring_group = PROCESS_GROUP.RING_PG + # we need to create a 1D submesh + tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type="cuda") + + global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH, _DP_DEVICE_MESH, _SP_RING_GROUP, _SP_ULYSSES_GROUP, _DRAFT_DP_GROUP, _DRAFT_SP_GROUP + _DEVICE_MESH = device_mesh + _TP_GROUP = tp_group + _TP_DEVICE_MESH = tp_device_mesh + _SP_ULYSSES_GROUP = sp_ulysses_group + _SP_RING_GROUP = sp_ring_group + _DP_GROUP = dp_group + _DRAFT_DP_GROUP = draft_device_mesh.get_group("draft_dp") + _DRAFT_SP_GROUP = draft_device_mesh.get_group("sp") + _DP_DEVICE_MESH = dist.DeviceMesh.from_group(dp_group, device_type="cuda") + + +def destroy_distributed(): + global _TP_GROUP, _DP_GROUP, _SP_ULYSSES_GROUP, _SP_RING_GROUP, _DRAFT_DP_GROUP + dist.destroy_process_group(_TP_GROUP) + dist.destroy_process_group(_DP_GROUP) + dist.destroy_process_group(_SP_ULYSSES_GROUP) + dist.destroy_process_group(_SP_RING_GROUP) + dist.destroy_process_group(_DRAFT_DP_GROUP) + dist.destroy_process_group(_DRAFT_SP_GROUP) + dist.destroy_process_group() + + +def shard_tensor( + tensor: torch.Tensor, process_group: dist.ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + rank = dist.get_rank(process_group) + size = dist.get_world_size(process_group) + return tensor.chunk(size, dim=dim)[rank].contiguous() + + +def gather_tensor( + tensor: torch.Tensor, process_group: dist.ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + size = dist.get_world_size(process_group) + obj_list = [torch.empty_like(tensor) for _ in range(size)] + dist.all_gather(obj_list, tensor, group=process_group) + gather_tensor = torch.cat(obj_list, dim=dim) + return gather_tensor + + +def all_gather_tensor( + local_tensor: torch.Tensor, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): + sp_world_size = dist.get_world_size(group=group) + output_shape = list(local_tensor.shape) + output_shape[0] = output_shape[0] * sp_world_size + output = torch.empty( + output_shape, dtype=local_tensor.dtype, device=local_tensor.device + ) + dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) + return output + + +# Adapted from https://github.com/volcengine/verl/blob/a0e8e4472b8b472409defb0c8fcc5162301450af/verl/utils/ulysses.py#L194 +class Gather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_tensor: torch.Tensor, + gather_dim: int, + grad_scaler: bool = True, + async_op=False, + ) -> torch.Tensor: + ctx.group = group + ctx.gather_dim = gather_dim + ctx.grad_scaler = grad_scaler + ctx.async_op = async_op + + sp_world_size = dist.get_world_size(group=group) + ctx.sp_world_size = sp_world_size + + sp_rank = dist.get_rank(group=group) + ctx.sp_rank = sp_rank + + local_shape = list(local_tensor.size()) + split_size = local_shape[0] + part_size = local_shape[gather_dim] # store original size + ctx.part_size = part_size + + output = all_gather_tensor(local_tensor, group, async_op) + return torch.cat(output.split(split_size, dim=0), dim=gather_dim) + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Any: + if ctx.grad_scaler: + grad_output = grad_output * ctx.sp_world_size + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ + ctx.sp_rank + ].contiguous(), + None, + None, + None, + None, + ) + + +def gather_outputs_and_unpad( + x: torch.Tensor, + gather_dim: int, + grad_scaler: bool = True, + group: Optional[dist.ProcessGroup] = None, +): + """ + Gather a tensor across a process group and optionally unpad its padded elements. + + Args: + x (Tensor): Input tensor to gather. + gather_dim (int): Dimension along which to gather across ranks. + grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. + group (ProcessGroup, optional): Process group for gathering. If None, uses + `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged. + + Returns: + Tensor: The gathered tensor, with padding removed if requested. + """ + if not group: + group = get_draft_sp_group() + if torch.distributed.get_world_size(group) == 1: + return x + x = Gather.apply(group, x, gather_dim, grad_scaler) + return x + + +def is_tp_rank_0(): + """Return True if current process is rank 0 in its TP group.""" + tp_group = get_tp_group() + if tp_group is None: + return True + return dist.get_rank(group=tp_group) == 0 diff --git a/progress/SpecForge/specforge/layers/__init__.py b/progress/SpecForge/specforge/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b71718d39de7248cd0c33732c920f129ddd40001 --- /dev/null +++ b/progress/SpecForge/specforge/layers/__init__.py @@ -0,0 +1,10 @@ +from .embedding import VocabParallelEmbedding +from .linear import ColumnParallelLinear, RowParallelLinear +from .lm_head import ParallelLMHead + +__all__ = [ + "VocabParallelEmbedding", + "ColumnParallelLinear", + "RowParallelLinear", + "ParallelLMHead", +] diff --git a/progress/SpecForge/specforge/layers/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df5467de25f62c3f33d2cbfdcb1a9903739a8f9 Binary files /dev/null and b/progress/SpecForge/specforge/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/__pycache__/embedding.cpython-311.pyc b/progress/SpecForge/specforge/layers/__pycache__/embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dfae930725f39c3410799d4217b995652aa5e33 Binary files /dev/null and b/progress/SpecForge/specforge/layers/__pycache__/embedding.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/__pycache__/linear.cpython-311.pyc b/progress/SpecForge/specforge/layers/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..724a5ca254360bfd37c16f44f83ecc0498cac198 Binary files /dev/null and b/progress/SpecForge/specforge/layers/__pycache__/linear.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/__pycache__/lm_head.cpython-311.pyc b/progress/SpecForge/specforge/layers/__pycache__/lm_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8e913fcbbeffac23ac94d08851b9ce2a54aad69 Binary files /dev/null and b/progress/SpecForge/specforge/layers/__pycache__/lm_head.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/embedding.py b/progress/SpecForge/specforge/layers/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..336d5776c5d1585a0a8ec0f3a24d2458d77703aa --- /dev/null +++ b/progress/SpecForge/specforge/layers/embedding.py @@ -0,0 +1,132 @@ +import math +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class VocabParallelEmbedding(nn.Module): + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + + # tp-realted + self.tp_group = get_tp_group() + self.tp_rank = dist.get_rank(self.tp_group) + self.tp_size = dist.get_world_size(self.tp_group) + + # deal with the case where the embedding is not divisible by the TP size + self.num_embeddings_per_shard = math.ceil(num_embeddings / self.tp_size) + self.padded_num_embeddings = ( + self.num_embeddings_per_shard * self.tp_size - self.num_embeddings + ) + self.vocab_start_index = self.tp_rank * self.num_embeddings_per_shard + self.vocab_end_index = min( + self.vocab_start_index + self.num_embeddings_per_shard, + self.num_embeddings, + ) + + if ( + padding_idx is not None + and padding_idx >= self.vocab_start_index + and padding_idx < self.vocab_end_index + ): + self.padding_idx = padding_idx - self.vocab_start_index + else: + self.padding_idx = None + + self.weight = nn.Parameter( + torch.empty( + (self.num_embeddings_per_shard, self.embedding_dim), **factory_kwargs + ), + requires_grad=True, + ) + self.reset_parameters() + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "weight" in state_dict: + value = state_dict["weight"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, 0, 0, padding_size)) + state_dict["weight"] = shard_tensor(value, self.tp_group, 0) + + def reset_parameters(self) -> None: + torch.nn.init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def generate_mask(self, input_): + # generate the mask for the vocab which is only owned by the current rank + mask = (input_ >= self.vocab_start_index) & (input_ < self.vocab_end_index) + return mask + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + mask = self.generate_mask(input_) + masked_input = input_ - self.vocab_start_index + masked_input[~mask] = 0 + else: + masked_input = input_ + + output_parallel = F.embedding( + masked_input, + self.weight, + padding_idx=self.padding_idx, + max_norm=self.max_norm, + norm_type=self.norm_type, + scale_grad_by_freq=self.scale_grad_by_freq, + sparse=self.sparse, + ) + + # Mask the output embedding. + if self.tp_size > 1: + output_parallel[~mask] = 0 + # Reduce across all the model parallel GPUs. + dist.all_reduce(output_parallel, op=dist.ReduceOp.SUM, group=self.tp_group) + output = output_parallel + else: + output = output_parallel + return output diff --git a/progress/SpecForge/specforge/layers/linear.py b/progress/SpecForge/specforge/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c512d2139d32b233efa5795a234a1426d0e5e4 --- /dev/null +++ b/progress/SpecForge/specforge/layers/linear.py @@ -0,0 +1,204 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class RowParallelLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + kv_head_replicas=False, + layout_type: str = "normal", + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.layout_type = layout_type + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + self.in_features = in_features + self.out_features = out_features + + if kv_head_replicas: + self.in_features_per_shard = in_features + else: + self.in_features_per_shard = in_features // self.tp_size + self.weight = nn.Parameter( + torch.empty(self.out_features, self.in_features_per_shard, **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + """ + This is a state dict hook to be triggered before loading the state dict. This will shard the weights and biases according to the layout type. + """ + if self.layout_type == "normal": + self.handle_normal_layout(state_dict, *args) + else: + raise ValueError(f"Invalid layout type: {self.layout_type}") + + def handle_normal_layout(self, state_dict, *args): + # shard the weights + if "weight" in state_dict: + state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, -1) + + if "bias" in state_dict and self.tp_rank != 0: + state_dict["bias"] = torch.zeros_like(state_dict["bias"]) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"RowParallelLinear(in_features={self.in_features_per_shard}, out_features={self.out_features}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" + + +class ColumnParallelLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + device=None, + dtype=None, + layout_type: str = "normal", + kv_head_replicas=False, + kv_head_idx=None, + total_num_kv_heads=None, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.layout_type = layout_type + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + self.in_features = in_features + self.out_features = out_features + self.kv_head_replicas = kv_head_replicas + self.kv_head_idx = kv_head_idx + self.total_num_kv_heads = total_num_kv_heads + if self.kv_head_replicas: + self.out_features_per_shard = out_features + else: + self.out_features_per_shard = out_features // self.tp_size + + self.weight = nn.Parameter( + torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs) + ) + if bias: + self.bias = nn.Parameter( + torch.empty(self.out_features_per_shard, **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + """ + This is a state dict hook to be triggered before loading the state dict. This will shard the weights and biases according to the layout type. + """ + if self.kv_head_replicas: + assert self.kv_head_idx is not None + assert self.layout_type == "normal" + self.handle_kv_head_replicas(state_dict, *args) + else: + if self.layout_type == "normal": + self.handle_normal_layout(state_dict, *args) + elif self.layout_type == "merged_qkv": + self.handle_merged_qkv(state_dict, *args) + elif self.layout_type == "gate_up": + self.handle_gate_up_layout(state_dict, *args) + else: + raise ValueError(f"Invalid layout type: {self.layout_type}") + + def handle_kv_head_replicas(self, state_dict, *args): + """ + This is a special case for GQA where the key/value are split according to the number of kv heads and the head which belongs to this rank. + As the TP size is larger than the number of kv heads, we only keep one kv head per rank. + """ + if "weight" in state_dict: + state_dict["weight"] = state_dict["weight"].chunk( + self.total_num_kv_heads, dim=0 + )[self.kv_head_idx] + if "bias" in state_dict and state_dict["bias"] is not None: + state_dict["bias"] = state_dict["bias"].chunk( + self.total_num_kv_heads, dim=0 + )[self.kv_head_idx] + + def handle_normal_layout(self, state_dict, *args): + """ + This shards the weights and biases along the column dimension. + """ + # shard the weights + if "weight" in state_dict: + state_dict["weight"] = shard_tensor(state_dict["weight"], self.tp_group, 0) + + if "bias" in state_dict and state_dict["bias"] is not None: + state_dict["bias"] = shard_tensor(state_dict["bias"], self.tp_group, 0) + + def handle_gate_up_layout(self, state_dict, *args): + """ + This handles the gate_up layout where the gate and up weights are concatenated along the column dimension. + """ + if "weight" in state_dict: + gate, up = state_dict["weight"].chunk(2, dim=0) + gate = shard_tensor(gate, self.tp_group, 0) + up = shard_tensor(up, self.tp_group, 0) + state_dict["weight"] = torch.cat((gate, up), dim=0) + + if "bias" in state_dict and state_dict["bias"] is not None: + gate, up = state_dict["bias"].chunk(2, dim=0) + gate = shard_tensor(gate, self.tp_group, 0) + up = shard_tensor(up, self.tp_group, 0) + state_dict["bias"] = torch.cat((gate, up), dim=0) + + def handle_merged_qkv(self, state_dict, *args): + """ + This handles the merged QKV layout where the q, k, v weights are concatenated along the column dimension. + """ + if "weight" in state_dict: + # need to split into qkv and take the correct chunk for the rank + q, k, v = state_dict["weight"].chunk(3, dim=0) + q = shard_tensor(q, self.tp_group, 0) + k = shard_tensor(k, self.tp_group, 0) + v = shard_tensor(v, self.tp_group, 0) + state_dict["weight"] = torch.cat((q, k, v), dim=0) + + if "bias" in state_dict and state_dict["bias"] is not None: + q, k, v = state_dict["bias"].chunk(3, dim=0) + q = shard_tensor(q, self.tp_group, 0) + k = shard_tensor(k, self.tp_group, 0) + v = shard_tensor(v, self.tp_group, 0) + state_dict["bias"] = torch.cat((q, k, v), dim=0) + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"ColumnParallelLinear(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" diff --git a/progress/SpecForge/specforge/layers/lm_head.py b/progress/SpecForge/specforge/layers/lm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d50b089da761b2b85396fbf8b40aa1a5d65133 --- /dev/null +++ b/progress/SpecForge/specforge/layers/lm_head.py @@ -0,0 +1,109 @@ +import math +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.distributed import get_tp_group, shard_tensor + + +class ParallelLMHead(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.in_features = in_features + self.out_features = out_features + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.tp_rank = dist.get_rank(self.tp_group) + + # tp-related + self.out_features_per_shard = math.ceil(out_features / self.tp_size) + self.padded_out_features = ( + self.out_features_per_shard * self.tp_size - out_features + ) + assert ( + self.out_features_per_shard * self.tp_size + == out_features + self.padded_out_features + ) + + self.weight = nn.Parameter( + torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs) + ) + self.bias = ( + nn.Parameter(torch.zeros(self.out_features_per_shard, **factory_kwargs)) + if bias + else None + ) + + # init params + self.reset_parameters() + + # handle weight loading + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "weight" in state_dict: + value = state_dict["weight"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, 0, 0, padding_size)) + state_dict["weight"] = shard_tensor(value, self.tp_group, 0) + + if "bias" in state_dict: + value = state_dict["bias"] + + # pad this value if it is not divisible by the TP size + if value.shape[0] % self.tp_size != 0: + padding_size = self.tp_size - value.shape[0] % self.tp_size + value = F.pad(value, (0, padding_size)) + state_dict["bias"] = shard_tensor(value, self.tp_group, 0) + + def forward(self, hidden: torch.Tensor, gather_output: bool = False): + """ + hidden: [B, T, H] or [N, H] + returns: + - if gather_output=False: local logits [*, local_vocab] and (start,end) for stitching + - if gather_output=True: full logits [*, vocab] via all-gather (use for inference) + """ + orig_shape = hidden.shape + hidden = hidden.reshape(-1, self.in_features) # [N, H] + + local_logits = hidden @ self.weight.T # [N, local_vocab] + + if self.bias is not None: + local_logits = local_logits + self.bias + + if not gather_output or self.tp_size == 1: + return local_logits.view( + *orig_shape[:-1], self.out_features_per_shard + ).contiguous() + else: + # all-gather shards along vocab dim + chunks = [torch.empty_like(local_logits) for _ in range(self.tp_size)] + dist.all_gather(chunks, local_logits, group=self.tp_group) + full = torch.cat(chunks, dim=-1)[ + :, : self.out_features + ] # trim padding from ceil-div + return full.view(*orig_shape[:-1], self.out_features).contiguous() + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self): + return f"ParallelLMHead(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})" diff --git a/progress/SpecForge/specforge/layers/ring/__init__.py b/progress/SpecForge/specforge/layers/ring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0a04a8f5eae08db74b9697dd8d1e2ae946edc9 --- /dev/null +++ b/progress/SpecForge/specforge/layers/ring/__init__.py @@ -0,0 +1,12 @@ +# adapt from https://github.com/feifeibear/long-context-attention/tree/main/yunchang +from .ring_flash_attn import ( + ring_flash_attn_func, + ring_flash_attn_kvpacked_func, + ring_flash_attn_qkvpacked_func, +) + +__all__ = [ + "ring_flash_attn_func", + "ring_flash_attn_kvpacked_func", + "ring_flash_attn_qkvpacked_func", +] diff --git a/progress/SpecForge/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b4ac3ca09223c6bec8db81d10b9a9c2c136cc67 Binary files /dev/null and b/progress/SpecForge/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc b/progress/SpecForge/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea2cee1011f0e85a2360742d1f210116a91979f Binary files /dev/null and b/progress/SpecForge/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/ring/__pycache__/utils.cpython-311.pyc b/progress/SpecForge/specforge/layers/ring/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c93520b914bc1c4ffbb8c19e638b6e433872f78 Binary files /dev/null and b/progress/SpecForge/specforge/layers/ring/__pycache__/utils.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/layers/ring/ring_flash_attn.py b/progress/SpecForge/specforge/layers/ring/ring_flash_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3c89b7e4a33431b83d4abe7bb64fd7b52a396523 --- /dev/null +++ b/progress/SpecForge/specforge/layers/ring/ring_flash_attn.py @@ -0,0 +1,336 @@ +import torch +from yunchang.kernels import AttnType, select_flash_attn_impl + +from .utils import RingComm, update_out_and_lse + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, + attn_processor=None, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if not causal or step <= comm.rank: + fn = select_flash_attn_impl( + attn_type, stage="fwd-only", attn_processor=attn_processor + ) + block_out, block_lse = fn( + q, + k, + v, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal and step == 0, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=True and dropout_p > 0, + ) + if attn_type == AttnType.SPARSE_SAGE: + out, lse = block_out, block_lse + else: + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + if attn_type != AttnType.SPARSE_SAGE: + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + attn_type: AttnType = AttnType.FA, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + fn = select_flash_attn_impl(attn_type, stage="bwd-only") + fn( + dout, + q, + k, + v, + out, + softmax_lse, + block_dq_buffer, + block_dk_buffer, + block_dv_buffer, + dropout_p, + softmax_scale, + bwd_causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, + ) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk = next_dk + dv = next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk) + next_dv = d_kv_comm.send_recv(dv) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + group, + attn_type, + attn_processor, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=False, + attn_type=attn_type, + attn_processor=attn_processor, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + ctx.attn_type = attn_type + ctx.attn_processor = attn_processor + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + softcap=ctx.softcap, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + attn_type=ctx.attn_type, + ) + return ( + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, + attn_type: AttnType = AttnType.FA, + attn_processor=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + group, + attn_type, + attn_processor, + ) diff --git a/progress/SpecForge/specforge/layers/ring/utils.py b/progress/SpecForge/specforge/layers/ring/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14d6a7817dc9358d64e4b93465e1b1f33d870923 --- /dev/null +++ b/progress/SpecForge/specforge/layers/ring/utils.py @@ -0,0 +1,119 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "RingComm"] + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty( + (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device + ) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + # print(f"send_recv: empty_like {to_send.shape}") + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/progress/SpecForge/specforge/lr_scheduler.py b/progress/SpecForge/specforge/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3276c79806a13c5223f0042ffc3abf6cad52ce --- /dev/null +++ b/progress/SpecForge/specforge/lr_scheduler.py @@ -0,0 +1,260 @@ +from warnings import warn + +from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR +from torch.optim.lr_scheduler import LRScheduler as _LRScheduler + + +class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class TwoStageScheduler(_LRScheduler): + def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1): + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = { + key: value for key, value in self.__dict__.items() if key not in "optimizer" + } + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type( + state_dict["after_scheduler"] + ).__name__ + state_dict["after_scheduler_dict"] = state_dict[ + "after_scheduler" + ].state_dict() + del state_dict["after_scheduler"] + else: + raise NotImplementedError() + return state_dict + + def load_state_dict(self, state_dict): + if "after_scheduler_dict" not in state_dict: + warn( + "after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior." + ) + else: + self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"]) + state_dict = { + key: value + for key, value in state_dict.items() + if key not in ("after_scheduler_type", "after_scheduler_dict") + } + super().load_state_dict(state_dict) + + +class DelayerScheduler(TwoStageScheduler): + """Starts with a flat lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau) + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): + if delay_epochs < 0: + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") + self.delay_epochs = delay_epochs + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + + return self.base_lrs + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.delay_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(DelayerScheduler, self).step(epoch) + + +class WarmupScheduler(TwoStageScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs then applies + the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): + self.warmup_epochs = int(warmup_epochs) + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + return self.after_scheduler.get_lr() + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class WarmupDelayerScheduler(TwoStageScheduler): + """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule + until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + warmup_epochs (int): Number of epochs to linearly warmup lr until starting applying the scheduler. + delay_epochs (int): Number of epochs to keep the initial lr until starting applying the scheduler. + after_scheduler (:class:`torch.optim.lr_scheduler`): After target_epoch, use this scheduler. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1 + ): + if delay_epochs < 0: + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") + if warmup_epochs < 0: + raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") + self.warmup_epochs = warmup_epochs + self.delay_epochs = delay_epochs + super().__init__(optimizer, after_scheduler, last_epoch) + + def get_lr(self): + if self.last_epoch >= self.warmup_epochs + self.delay_epochs: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + # reset lr to base_lr + for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): + group["lr"] = base_lr + self.finished = True + with _enable_get_lr_call(self.after_scheduler): + return self.after_scheduler.get_lr() + elif self.last_epoch >= self.warmup_epochs: + return self.base_lrs + + return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs] + + def step(self, epoch=None): + if self.finished: + if epoch is None: + self.after_scheduler.step(None) + self._last_lr = self.after_scheduler.get_last_lr() + else: + self.after_scheduler.step(epoch - self.warmup_epochs) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super().step(epoch) + + +class CosineAnnealingLR(_CosineAnnealingLR): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, + optimizer, + total_steps: int, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs, + ): + super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch) + + +class CosineAnnealingWarmupLR(WarmupScheduler): + """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied. + + Args: + optimizer (:class:`torch.optim.Optimizer`): Wrapped optimizer. + total_steps (int): Number of total training steps. + warmup_steps (int, optional): Number of warmup steps, defaults to 0. + eta_min (int, optional): Minimum learning rate, defaults to 0. + last_epoch (int, optional): The index of last epoch, defaults to -1. When last_epoch=-1, + the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. + """ + + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + eta_min: float = 0.0, + last_epoch: int = -1, + ): + base_scheduler = _CosineAnnealingLR( + optimizer, + total_steps - warmup_steps, + eta_min=eta_min, + last_epoch=last_epoch, + ) + super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/progress/SpecForge/specforge/modeling/__init__.py b/progress/SpecForge/specforge/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09999d60bc39243219b2c346154fe51ff4594dce --- /dev/null +++ b/progress/SpecForge/specforge/modeling/__init__.py @@ -0,0 +1,19 @@ +# from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel +from .auto import AutoDraftModelConfig, AutoEagle3DraftModel +from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .target.eagle3_target_model import ( + CustomEagle3TargetModel, + HFEagle3TargetModel, + SGLangEagle3TargetModel, + get_eagle3_target_model, +) + +__all__ = [ + "LlamaForCausalLMEagle3", + "SGLangEagle3TargetModel", + "HFEagle3TargetModel", + "CustomEagle3TargetModel", + "get_eagle3_target_model", + "AutoDraftModelConfig", + "AutoEagle3DraftModel", +] diff --git a/progress/SpecForge/specforge/modeling/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/modeling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4549331eb2f38533adcdb13508c666f6299e4d7e Binary files /dev/null and b/progress/SpecForge/specforge/modeling/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc b/progress/SpecForge/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc80ab67519d23e84a36dd278e2cbdc080f630fa Binary files /dev/null and b/progress/SpecForge/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/__pycache__/auto.cpython-311.pyc b/progress/SpecForge/specforge/modeling/__pycache__/auto.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7239917b511e2cc843a19bd2bc67496c13a3799 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/__pycache__/auto.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/_mask_utils.py b/progress/SpecForge/specforge/modeling/_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bedb200299e24ecf531c117618e55837c49facbe --- /dev/null +++ b/progress/SpecForge/specforge/modeling/_mask_utils.py @@ -0,0 +1,73 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) diff --git a/progress/SpecForge/specforge/modeling/auto.py b/progress/SpecForge/specforge/modeling/auto.py new file mode 100644 index 0000000000000000000000000000000000000000..1e48a43e7a62748f802671500b23adf74c6dd03a --- /dev/null +++ b/progress/SpecForge/specforge/modeling/auto.py @@ -0,0 +1,175 @@ +import json +import os +from typing import Optional, Union + +import torch +from transformers import AutoConfig +from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase +from transformers import ( + GptOssConfig, + Llama4Config, + Llama4TextConfig, + LlamaConfig, + Phi3Config, + PretrainedConfig, + Qwen2Config, + Qwen3Config, + Qwen3MoeConfig, + modeling_utils, +) + +from .draft.llama3_eagle import LlamaForCausalLMEagle3 +from .target.custom_backend import ( + GptOssForCausalLM, + Llama4ForCausalLM, + LlamaForCausalLM, + Phi3ForCausalLM, + Qwen2ForCausalLM, + Qwen3ForCausalLM, + Qwen3MoeForCausalLM, +) + + +class AutoEagle3DraftModel(AutoModelForCausalLMBase): + # the model mapping is currently hardcoded, we should support lazy model mapping via registry + _model_mapping = { + LlamaConfig: LlamaForCausalLMEagle3, + } + + @classmethod + def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs): + """ + This class method takes a configuration object and create its model based on the + _model_mapping class variable. + + Args: + config (PretrainedConfig): A configuration object. + + Returns: + A model instance. + """ + # get the model class from the + _model_cls = cls._model_mapping[type(config)] + model = _model_cls(config, **config_kwargs) + + # Convert model to specified dtype if provided + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + *model_args, + **kwargs, + ): + original_warn = modeling_utils.logger.warning + + def filtered_warning(msg): + if "embed_tokens.weight" in str(msg) and "initialized" in str(msg): + return + original_warn(msg) + + modeling_utils.logger.warning = filtered_warning + + try: + model = super().from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + finally: + modeling_utils.logger.warning = original_warn + + return model + + +class AutoDistributedTargetModel(AutoModelForCausalLMBase): + # the model mapping is currently hardcoded, we should support lazy model mapping via registry + _model_mapping = { + Llama4TextConfig: [Llama4ForCausalLM], + Qwen3MoeConfig: [Qwen3MoeForCausalLM], + Qwen2Config: [Qwen2ForCausalLM], + LlamaConfig: [LlamaForCausalLM], + Qwen3Config: [Qwen3ForCausalLM], + Phi3Config: [Phi3ForCausalLM], + GptOssConfig: [GptOssForCausalLM], + } + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **config_kwargs, + ): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + ) + + if isinstance(config, Llama4Config): + config = config.text_config + + assert ( + type(config) in cls._model_mapping + ), f"Unsupported config type: {type(config)}" + model_cls = cls._model_mapping[type(config)][0] + model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + **config_kwargs, + ) + + if device is not None: + model = model.to(device) + else: + model = model.cuda() + return model + + +class AutoDraftModelConfig: + + _config_mapping = { + "LlamaForCausalLMEagle3": LlamaConfig, + } + + @classmethod + def from_file(cls, config_path: str): + """ + This class method takes a configuration file path and create its configuration object based on the + _config_mapping class variable. + + Args: + config_path (str): A path to a configuration file. + + Returns: + A configuration object. + """ + with open(config_path, "r") as f: + config = json.load(f) + + if "tie_word_embeddings" in config: + print("Set draft model tie_word_embeddings to False") + config["tie_word_embeddings"] = False + + # check for architectures + architectures = config.get("architectures", None) + + if architectures is None: + raise ValueError("No architectures found in the config file") + + if len(architectures) != 1: + raise ValueError("Only one architecture is supported") + + architecture = architectures[0] + + if architecture not in cls._config_mapping: + raise ValueError(f"Architecture {architecture} not supported") + + # If draft_vocab_size is not in config or is None, set draft_vocab_size to vocab_size + if "draft_vocab_size" not in config or config["draft_vocab_size"] is None: + config["draft_vocab_size"] = config.get("vocab_size", None) + + return cls._config_mapping[architecture].from_dict(config) diff --git a/progress/SpecForge/specforge/modeling/draft/__init__.py b/progress/SpecForge/specforge/modeling/draft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdc7e2f6fa02407dc4e4bab9c4b1e252c10aa62 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/__init__.py @@ -0,0 +1,17 @@ +from .base import Eagle3DraftModel +from .dflash import ( + DFlashDraftModel, + build_target_layer_ids, + extract_context_feature, + sample, +) +from .llama3_eagle import LlamaForCausalLMEagle3 + +__all__ = [ + "Eagle3DraftModel", + "DFlashDraftModel", + "LlamaForCausalLMEagle3", + "build_target_layer_ids", + "extract_context_feature", + "sample", +] diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c698ade05722f32e47960f22f2fd96d60eb933 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/base.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8ae0e8c1c442d133c68e59fe2ffc88ead75453b Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/base.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da8c5a7c466cff0eaf979dd9218cceaf69f5afef Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-313.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88247e10fbfe0918d4f103644b12d4d4d948b0f Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash.cpython-313.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1b405bd374b83641891e51b825f9721f7583562 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-313.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4777f52723e8b9f55d42df83ce9daabaafe273cd Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/dflash_lora.cpython-313.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424380688eca29c4c8405e48ac740dce34a072ce Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc b/progress/SpecForge/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e56fcf3d5fc210b904ddcd6549908fa55b217f31 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/draft/base.py b/progress/SpecForge/specforge/modeling/draft/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b5584a759d78a072903e0e76999b1674a62f0a88 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/base.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers.cache_utils import Cache +from transformers.modeling_utils import PreTrainedModel + +from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask + + +class Eagle3DraftModel(PreTrainedModel, ABC): + """ + This is the base class for the Eagle3 draft model implementation. The child class needs to implement + the abstract methods to support training with TTT. + """ + + @abstractmethod + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Embed the input ids. + """ + + @abstractmethod + def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Project the concatenated hidden states from the high, medium and low layers to the target hidden size. + """ + + @abstractmethod + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Compute the logits of the draft model. + """ + + def prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + batch_size: int, + seq_length: int, + past_key_values_length: int, + ) -> torch.Tensor: + """ + Prepare the attention mask of the draft model. + """ + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if seq_length > 1: + combined_attention_mask = _make_causal_mask( + (batch_size, seq_length), + hidden_states.dtype, + device=hidden_states.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, hidden_states.dtype, tgt_len=seq_length + ).to(hidden_states.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + return combined_attention_mask + + @abstractmethod + def backbone( + self, + input_embeds: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Cache] = None, + use_cache: bool = True, + ) -> torch.Tensor: + """ + The backbone of the draft model. + """ + + def freeze_embedding(self) -> None: + """ + Freeze the embeddings of the draft model so that they are not updated during training. + """ + self.embed_tokens.weight.requires_grad = False + + @torch.no_grad() + def load_embedding( + self, model_path: str, embedding_key: str = "model.embed_tokens.weight" + ) -> None: + """ + Load the embedding of the draft model. + + Args: + model_path (str): Path to the target model. Can be either a Hugging Face + repository ID or a local directory path containing the model files. + """ + if os.path.exists(model_path): + # model_path is a local directory + # check if there is file ending with index.json + glob_path = os.path.join(model_path, "*.index.json") + index_json_path = glob.glob(glob_path) + + if len(index_json_path) == 0: + # No index.json found, look for single model file + safetensors_path = os.path.join(model_path, "model.safetensors") + if os.path.exists(safetensors_path): + with safe_open(safetensors_path, framework="pt") as f: + self.embed_tokens.weight.copy_(f.get_tensor(embedding_key)) + return + + pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") + if os.path.exists(pytorch_model_path): + state_dict = torch.load(pytorch_model_path, map_location="cpu") + self.embed_tokens.weight.copy_(state_dict[embedding_key]) + return + + raise FileNotFoundError( + f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}" + ) + if len(index_json_path) > 1: + raise FileNotFoundError( + f"Multiple index.json files found in {model_path}" + ) + index_json_path = index_json_path[0] + + with open(index_json_path, "r") as f: + index_json = json.load(f) + ckpt_file = index_json["weight_map"][embedding_key] + + if ckpt_file.endswith(".safetensors"): + with safe_open( + os.path.join(model_path, ckpt_file), framework="pt" + ) as f: + emb_tokens = f.get_tensor(embedding_key) + else: + state_dict = torch.load(os.path.join(model_path, ckpt_file)) + emb_tokens = state_dict[embedding_key] + self.embed_tokens.weight.copy_(emb_tokens) + else: + # this is the case where model_path is a huggingface repository + # we first need to locate its local cache + local_cache_path = snapshot_download(repo_id=model_path) + self.load_embedding(local_cache_path, embedding_key) + + def load_vocab_mapping(self, file_path: str) -> None: + """ + Load the vocab buffers of the draft model. + + Args: + file_path (str): The path to the vocab mapping file. + """ + assert hasattr(self, "t2d") and hasattr( + self, "d2t" + ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" + vocab_mapping = torch.load(file_path) + self.t2d.copy_(vocab_mapping["t2d"]) + self.d2t.copy_(vocab_mapping["d2t"]) + self.vocab_mapping_loaded = True diff --git a/progress/SpecForge/specforge/modeling/draft/dflash.py b/progress/SpecForge/specforge/modeling/draft/dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..0aea03fe130ec6070d8d94f8d11c37ee9885782f --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/dflash.py @@ -0,0 +1,379 @@ +from typing import Callable, Optional + +import torch +from torch import nn +from transformers import DynamicCache +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.qwen3.modeling_qwen3 import ( + ALL_ATTENTION_FUNCTIONS, + FlashAttentionKwargs, + GradientCheckpointingLayer, + Qwen3Config, + Qwen3MLP, + Qwen3PreTrainedModel, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + eager_attention_forward, + rotate_half, +) +from typing_extensions import Tuple, Unpack + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3DFlashAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + q = self.q_proj(hidden_states) + q = q.view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view( + bsz, ctx_len + q_len, -1, self.head_dim + ) + v = torch.cat([v_ctx, v_noise], dim=1).view( + bsz, ctx_len + q_len, -1, self.head_dim + ) + k = self.k_norm(k).transpose(1, 2) + v = v.transpose(1, 2) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) + attn_fn: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + target_hidden: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [(num_target_layers // 2)] + start = 1 + end = num_target_layers - 3 + span = end - start + target_layer_ids = [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + return target_layer_ids + + +def extract_context_feature( + hidden_states: list[torch.Tensor], + layer_ids: Optional[list[int]], +) -> torch.Tensor: + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + + +class DFlashDraftModel(Qwen3PreTrainedModel): + config_class = Qwen3Config + _no_split_modules = ["Qwen3DFlashDecoderLayer"] + + def __init__(self, config) -> None: + super().__init__(config) + self.config = config + self.layers = nn.ModuleList( + [ + Qwen3DFlashDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + dflash_config = getattr(config, "dflash_config", {}) or {} + self.target_layer_ids = dflash_config.get( + "target_layer_ids", + build_target_layer_ids(config.num_target_layers, config.num_hidden_layers), + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + self.fc = nn.Linear( + len(self.target_layer_ids) * config.hidden_size, + config.hidden_size, + bias=False, + ) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.mask_token_id = dflash_config.get("mask_token_id", None) + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + target_hidden: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + **kwargs, + ) -> CausalLMOutputWithPast: + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + return self.norm(hidden_states) + + @torch.inference_mode() + def spec_generate( + self, + target: nn.Module, + input_ids: torch.LongTensor, + max_new_tokens: int, + stop_token_ids: list[int], + temperature: float, + ): + self.eval() + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + max_new_tokens + + block_size = self.block_size + output_ids = torch.full( + (1, max_length + block_size), + self.mask_token_id, + dtype=torch.long, + device=target.device, + ) + position_ids = torch.arange( + output_ids.shape[1], device=target.device + ).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + # Prefill stage + output = target( + input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + use_cache=True, + logits_to_keep=1, + output_hidden_states=True, + ) + + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = sample( + output.logits, temperature + ) + target_hidden = extract_context_feature( + output.hidden_states, self.target_layer_ids + ) + + # Decode stage + acceptance_lengths = [] + start = input_ids.shape[1] + while start < max_length: + block_output_ids = output_ids[:, start : start + block_size].clone() + block_position_ids = position_ids[:, start : start + block_size] + noise_embedding = target.model.embed_tokens(block_output_ids) + draft_logits = target.lm_head( + self( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[ + :, past_key_values_draft.get_seq_length() : start + block_size + ], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + )[:, -block_size + 1 :, :] + ) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = sample(draft_logits) + + output = target( + block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + use_cache=True, + output_hidden_states=True, + ) + + posterior = sample(output.logits, temperature) + acceptance_length = ( + (block_output_ids[:, 1:] == posterior[:, :-1]) + .cumprod(dim=1) + .sum(dim=1)[0] + .item() + ) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[ + :, : acceptance_length + 1 + ] + output_ids[:, start + acceptance_length + 1] = posterior[ + :, acceptance_length + ] + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = extract_context_feature( + output.hidden_states, self.target_layer_ids + )[:, : acceptance_length + 1, :] + acceptance_lengths.append(acceptance_length + 1) + if stop_token_ids is not None and any( + stop_token_id in output_ids[:, num_input_tokens:] + for stop_token_id in stop_token_ids + ): + break + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != self.mask_token_id] + if stop_token_ids is not None: + stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device) + stop_token_indices = torch.isin( + output_ids[0][num_input_tokens:], stop_token_ids + ).nonzero(as_tuple=True)[0] + if stop_token_indices.numel() > 0: + output_ids = output_ids[ + :, : num_input_tokens + stop_token_indices[0] + 1 + ] + + return output_ids diff --git a/progress/SpecForge/specforge/modeling/draft/dflash_lora.py b/progress/SpecForge/specforge/modeling/draft/dflash_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0385d0a90aa6ed3464cd528ffecbf3cceab6b5d0 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/dflash_lora.py @@ -0,0 +1,145 @@ +"""DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation.""" + +from typing import Optional + +import torch +import torch.nn as nn +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoModelForCausalLM, AutoConfig + + +class DFlashLoRADraftModel(nn.Module): + """ + Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters. + The model learns to predict all tokens in a block in parallel (1-step diffusion), + using a modified DFlash attention mask over the full sequence. + + Attention mask design: + - context token i: standard causal (attends to j <= i) + - block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional) + """ + + def __init__( + self, + base_model: nn.Module, + block_size: int, + mask_token_id: int, + ): + super().__init__() + self.model = base_model + self.block_size = block_size + self.mask_token_id = mask_token_id + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + lora_rank: int = 16, + lora_alpha: int = 32, + lora_dropout: float = 0.05, + lora_target_modules: Optional[list] = None, + block_size: int = 16, + mask_token_id: int = 151669, + torch_dtype: torch.dtype = torch.bfloat16, + device_map: str = "cuda", + trust_remote_code: bool = False, + attn_implementation: str = "sdpa", + **kwargs, + ) -> "DFlashLoRADraftModel": + """ + attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'. + 'flex_attention' uses torch BlockMask — zero extra memory for attention masks. + Do NOT use 'flash_attention_2' — it does not support 4D additive attention masks. + """ + if lora_target_modules is None: + lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] + + base_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + **kwargs, + ) + base_model = base_model.cuda() + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + target_modules=lora_target_modules, + bias="none", + ) + base_model = get_peft_model(base_model, lora_config) + # Cast LoRA parameters to match base model dtype for FSDP compatibility + for param in base_model.parameters(): + if param.requires_grad and param.dtype != torch_dtype: + param.data = param.data.to(torch_dtype) + base_model.print_trainable_parameters() + + return cls( + base_model=base_model, + block_size=block_size, + mask_token_id=mask_token_id, + ) + + def gradient_checkpointing_enable(self, **kwargs): + self.model.gradient_checkpointing_enable(**kwargs) + + def parameters(self, *args, **kwargs): + return self.model.parameters(*args, **kwargs) + + def named_parameters(self, *args, **kwargs): + return self.model.named_parameters(*args, **kwargs) + + def train(self, mode=True): + self.model.train(mode) + return self + + def eval(self): + self.model.eval() + return self + + def save_pretrained(self, save_dir: str, **kwargs): + """Save only the LoRA adapter weights.""" + self.model.save_pretrained(save_dir, **kwargs) + + def get_lm_head(self) -> nn.Module: + """Return a reference to the lm_head through the PEFT model hierarchy.""" + base_model = self.model.get_base_model() + return base_model.lm_head + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_hidden_states: bool = False, + ): + """ + Forward pass through the LoRA-adapted model. + + Args: + input_ids: [bsz, seq_len] — noisy input (context real, block = anchor + MASKs) + attention_mask: DFlash attention mask — either [bsz, 1, seq_len, seq_len] (4D + additive) or a BlockMask (for flex_attention). + position_ids: [bsz, seq_len] + output_hidden_states: if True, return last hidden state instead of logits. + Used for chunked cross-entropy loss to avoid materializing full logits. + + Returns: + logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or + hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True. + """ + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + if output_hidden_states: + return outputs.hidden_states[-1] + return outputs.logits diff --git a/progress/SpecForge/specforge/modeling/draft/flex_attention.py b/progress/SpecForge/specforge/modeling/draft/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..50ca5f54dc658106c22d7a8a95553bf346b33525 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/flex_attention.py @@ -0,0 +1,127 @@ +import torch +import torch._dynamo as dynamo +from torch.nn.attention.flex_attention import ( + create_block_mask, + flex_attention, + or_masks, +) +from transformers.utils import is_torchdynamo_compiling + +dynamo.config.recompile_limit = 64 + + +# Reference Implementation https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py +class WrappedFlexAttention: + """ + We are doing a singleton class so that flex attention is compiled once when it's first called. + """ + + _instance = None + _is_flex_compiled = False + _compiled_flex_attention = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + # Create a new instance if one doesn't already exist + cls._instance = super().__new__(cls) + return cls._instance + + @torch.compiler.disable(recursive=False) + def __init__(self): + """ + Initialize or update the singleton instance. + """ + if not self._is_flex_compiled: + # Enable dynamic shapes to handle different input sizes + self._compiled_flex_attention = torch.compile( + flex_attention, + # mode="max-autotune-no-cudagraphs", + ) + self._is_flex_compiled = True + + def __call__(self): + return self._compiled_flex_attention + + +def compile_friendly_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs, +) -> torch.Tensor: + # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention + # Do not use compiled version if already compiling forward (it raises issues) + flex_attention_compiled = ( + WrappedFlexAttention()() if not is_torchdynamo_compiling() else flex_attention + ) + return flex_attention_compiled( + query, + key, + value, + **kwargs, + ) + + +class WrappedCreateBlockMask: + _instance = None + _is_create_block_mask_compiled = False + _compiled_create_block_mask = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @torch.compiler.disable(recursive=False) + def __init__(self): + if not self._is_create_block_mask_compiled: + self._compiled_create_block_mask = torch.compile(create_block_mask) + self._is_create_block_mask_compiled = True + + def __call__(self): + return self._compiled_create_block_mask + + +def compile_friendly_create_block_mask( + mask_mod, + B, + H, + Q_LEN, + KV_LEN, + device, +): + create_block_mask_compiled = ( + WrappedCreateBlockMask()() + if not is_torchdynamo_compiling() + else create_block_mask + ) + return create_block_mask_compiled( + mask_mod, + B, + H, + Q_LEN, + KV_LEN, + device, + ) + + +def generate_eagle3_mask( + seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, lck: int = 0 +): + + def causal_mask(b, h, q_idx, kv_idx): + # Causal will keep shrinking by 1 diagnol due to appended suffix + # Shirnk the causal by diagnol + causal_mask = q_idx >= kv_idx + padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b]) + return causal_mask & padding_mask + + def suffix_mask(b, h, q_idx, kv_idx): + suffix_mask = kv_idx >= Q_LEN + padding_mask = kv_idx % Q_LEN < seq_lengths[b] + diagnol_mask = (kv_idx - q_idx) % Q_LEN == 0 + return suffix_mask & padding_mask & diagnol_mask + + mask_mod = or_masks(causal_mask, suffix_mask) + mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_lck_{lck}" + return mask_mod diff --git a/progress/SpecForge/specforge/modeling/draft/llama3_eagle.py b/progress/SpecForge/specforge/modeling/draft/llama3_eagle.py new file mode 100644 index 0000000000000000000000000000000000000000..4a183307281b1868569c025bf2b297ac63fa6446 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/draft/llama3_eagle.py @@ -0,0 +1,1429 @@ +import math +import warnings +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache +from transformers.models.llama.configuration_llama import LlamaConfig +from yunchang.comm import SeqAllToAll4D + +from specforge.modeling.draft.flex_attention import ( + compile_friendly_create_block_mask, + compile_friendly_flex_attention, + generate_eagle3_mask, +) +from specforge.utils import print_with_rank + +from ...distributed import get_sp_ring_group, get_sp_ulysses_group +from ...layers.ring import ring_flash_attn_func +from .base import Eagle3DraftModel + +try: + from flash_attn import flash_attn_func +except ImportError: + warnings.warn( + "flash_attn is not found, falling back to flex_attention. " + "Please install flash_attn if you want to use the flash attention backend." + ) + flash_attn_func = None + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@torch.compile(dynamic=True) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=None, + low_freq_factor=None, + high_freq_factor=None, + orig_max_position=None, + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + # Llama3 style rotary embedding frequency scaling + if all( + v is not None + for v in [ + scaling_factor, + low_freq_factor, + high_freq_factor, + orig_max_position, + ] + ): + print_with_rank( + f"Using Llama3 style rotary embedding with scaling_factor={scaling_factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, orig_max_position={orig_max_position}" + ) + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + + low_freq_wavelen = orig_max_position / low_freq_factor + high_freq_wavelen = orig_max_position / high_freq_factor + wave_len = 2 * math.pi / inv_freq + + if low_freq_factor != high_freq_factor: + smooth = (orig_max_position / wave_len - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + else: + smooth = 0 + + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freq, + torch.where( + wave_len > low_freq_wavelen, + inv_freq / self.scaling_factor, + (1 - smooth) * inv_freq / self.scaling_factor + smooth * inv_freq, + ), + ) + inv_freq = new_freqs + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings + 20, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + @torch.compile(dynamic=True) + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len and seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaMutiRotaryEmbedding(LlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__(dim, max_position_embeddings, base, device) + self.scaling_factor = scaling_factor + + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.scaling_factor + sin = emb.sin() * self.scaling_factor + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / ( + max_val - min_val + ) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", + (emb.cos() * _mscale)[None, None, :, :].to(dtype), + persistent=False, + ) + self.register_buffer( + "sin_cached", + (emb.sin() * _mscale)[None, None, :, :].to(dtype), + persistent=False, + ) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if hasattr(config, "head_dim"): + self.head_dim = config.head_dim + else: + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + self.q_proj = nn.Linear( + self.hidden_size * 2, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=getattr(self.config, "rope_theta", 10000), + ) + else: + rope_scaling = self.config.rope_scaling + + def rope_get(key, default=None): + if isinstance(rope_scaling, dict): + return rope_scaling.get(key, default) + return getattr(rope_scaling, key, default) + + scaling_type = rope_get("rope_type", rope_get("type")) + scaling_factor = rope_get("factor") + + if scaling_type == "linear": + if scaling_factor is None: + raise ValueError( + "Linear RoPE scaling requires 'factor' in rope_scaling config." + ) + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + if scaling_factor is None: + raise ValueError( + "Dynamic RoPE scaling requires 'factor' in rope_scaling config." + ) + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "llama3": + # for nv type + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=getattr(self.config, "rope_theta", 10000), + scaling_factor=( + scaling_factor if scaling_factor is not None else 1.0 + ), + low_freq_factor=rope_get("low_freq_factor"), + high_freq_factor=rope_get("high_freq_factor"), + orig_max_position=rope_get("original_max_position_embeddings"), + ) + elif scaling_type == "mrope": + self.rotary_emb = LlamaMutiRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + elif scaling_type == "yarn": + self.rotary_emb = LlamaYarnRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + original_max_position_embeddings=rope_get( + "original_max_position_embeddings" + ), + scaling_factor=scaling_factor, + beta_fast=rope_get("beta_fast"), + beta_slow=rope_get("beta_slow"), + mscale=rope_get("mscale"), + mscale_all_dim=rope_get("mscale_all_dim"), + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if cache_hidden is None: + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=0.0, + ) + + else: + lck = len(cache_hidden[0]) + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + + k0 = cache_k[0] + v0 = cache_v[0] + + # causal + attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + lck = len(cache_k) + + attn_weights = attn_weights + attention_mask + + for i in range(1, lck): + ki = cache_k[i] + qi = query_states + kiq = ki + + attn_weightsi = (qi * kiq).sum(-1) / math.sqrt(self.head_dim) + attn_weights = torch.cat( + (attn_weights, attn_weightsi[..., None]), dim=-1 + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights0 = attn_weights[..., :q_len] + + attn_output = torch.matmul(attn_weights0, v0) + + for i in range(1, lck): + vi = cache_v[i] + attn_weightsi = attn_weights[..., q_len + i - 1] + attn_outputi = attn_weightsi[..., None] * vi + attn_output = attn_output + attn_outputi + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaFlexAttention(LlamaAttention): + """ + Attention layer implemented with flex attention. We keep the parameters consistent with LlamaAttention. + The used parameters are: + - hidden_states: input hidden states + - attention_mask: attention mask not expanded, straight from data loader. + - position_ids: position ids + - past_key_values: dynamic cache used for storing past key and value states. + """ + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + lck = past_seen_tokens // q_len + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + # Keep positions ids aligned when padding so the KV cache is unaffected. + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck + ) + + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + q_len, device=hidden_states.device + ) + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + + key_cache, value_cache = past_key_values.update( + key_states, + value_states, + layer_idx=0, # TODO: support multiple layers + cache_kwargs=cache_kwargs, + ) + + seq_lengths = attention_mask.sum(dim=-1) + # Shrink the attention mask to align with the padding to the right. + # This is equivalent to the shrinking logic in eagle3.py + seq_lengths -= lck + # TODO: Remove the usage of uncompiled create_block_mask after + # https://github.com/pytorch/pytorch/issues/160018 + if q_len <= 128: + create_block_mask_func = create_block_mask + flex_attention_func = flex_attention + else: + create_block_mask_func = compile_friendly_create_block_mask + flex_attention_func = compile_friendly_flex_attention + + block_mask = create_block_mask_func( + mask_mod=generate_eagle3_mask( + seq_lengths=seq_lengths, + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + lck=lck, + ), + B=bsz, + H=1, # Rely on broadcast + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + device=query_states.device, + ) + attn_output = flex_attention_func( + query=query_states, + key=key_cache.contiguous(), + value=value_cache.contiguous(), + block_mask=block_mask, + enable_gqa=True, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + attn_output = self.o_proj(attn_output) + return attn_output + + +class LlamaFlashAttention(LlamaAttention): + """ + Attention layer implemented with flash attention. We keep the parameters consistent with LlamaAttention. + The used parameters are: + - hidden_states: input hidden states + - position_ids: position ids + - cache_hidden: manual cache used for storing past key and value states + """ + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + + lck = 0 if cache_hidden is None else len(cache_hidden[0]) + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + unsqueeze_dim=2, + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) + + if cache_hidden is not None: + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + else: + cache_k = [key_states] + cache_v = [value_states] + + k0 = cache_k[0] + v0 = cache_v[0] + + assert ( + flash_attn_func is not None + ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + attn_output, lse, _ = flash_attn_func( + query_states, + k0, + v0, + dropout_p=0.0, + softmax_scale=1.0 / math.sqrt(self.head_dim), + causal=True, + return_attn_probs=True, + ) + lse = lse.transpose(1, 2) + + lck = len(cache_k) + if lck > 1: + q_shape_expanded = ( + bsz, + q_len, + self.num_key_value_heads, + self.num_key_value_groups, + self.head_dim, + ) + attn_outputs = [attn_output.view(q_shape_expanded)] + lses = [lse.view(q_shape_expanded[:-1])] + + for i in range(1, lck): + ki = cache_k[i].unsqueeze(-2) + qi = query_states.view(q_shape_expanded) + vi = cache_v[i].unsqueeze(-2) + + attn_outputs.append(vi) + lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim)) + + lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1) + attn_output = sum( + attn_outputi * torch.exp(lsei - lse).unsqueeze(-1) + for attn_outputi, lsei in zip(attn_outputs, lses) + ) + # lse is fp32, downcast attn_output back + attn_output = attn_output.to(self.o_proj.weight.dtype) + + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaUSPFlashAttention(LlamaAttention): + """ + LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging. + """ + + def __init__(self, config): + super().__init__(config) + assert ( + dist.is_initialized() + ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + raise NotImplementedError( + f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." + ) + self.ring_pg = get_sp_ring_group() + self.ulysses_pg = get_sp_ulysses_group() + self.sp_ring_degree = torch.distributed.get_world_size(self.ring_pg) + self.sp_ulysses_degree = torch.distributed.get_world_size(self.ulysses_pg) + self.ring_rank = torch.distributed.get_rank(self.ring_pg) + + self.scatter_idx = 2 + self.gather_idx = 1 + self.use_sync = False + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, q_len, _ = hidden_states.size() + local_q_len = q_len + + # ============================================================= + # 1. Projections & Ulysses Scatter + # ============================================================= + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = SeqAllToAll4D.apply( + self.ulysses_pg, + query_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + key_states = self.k_proj(hidden_states) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + key_states = SeqAllToAll4D.apply( + self.ulysses_pg, + key_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + value_states = self.v_proj(hidden_states) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + value_states = SeqAllToAll4D.apply( + self.ulysses_pg, + value_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + current_q_len = query_states.shape[1] + local_num_heads = query_states.shape[2] + + # Global length calculation (for RoPE) + global_q_len = q_len * self.sp_ring_degree * self.sp_ulysses_degree + # ============================================================= + # 2. RoPE & Cache Management + # ============================================================= + lck = 0 if cache_hidden is None else len(cache_hidden[0]) + + cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) + + # Update Cache (Eagle3 Logic: Cache is a list of tensors for tree branches) + if cache_hidden is not None: + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + else: + cache_k = [key_states] + cache_v = [value_states] + + # ============================================================= + # 3. Hybrid Attention Computation + # ============================================================= + + # 3.1 Main Sequence (Ring Attention) + out_ring, lse_ring, _ = ring_flash_attn_func( + query_states, + cache_k[0], + cache_v[0], + dropout_p=0.0, + softmax_scale=1.0 / math.sqrt(self.head_dim), + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + group=self.ring_pg, + ) + + if lse_ring.dim() == 3 and lse_ring.shape[1] == local_num_heads: + acc_lse = lse_ring.transpose(1, 2).contiguous() # -> [B, S, H] + else: + acc_lse = lse_ring + + assert ( + acc_lse.shape[1] == current_q_len + ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + + acc_out = out_ring + + # 3.2 Extras Branches (Eagle3 Point-wise Update) + if len(cache_k) > 1: + num_kv_heads_local = cache_k[0].shape[2] + local_groups = local_num_heads // num_kv_heads_local + + q_shape_expanded = ( + bsz, + current_q_len, + num_kv_heads_local, + local_groups, + self.head_dim, + ) + qi_reshaped = query_states.view(q_shape_expanded) # [B, S, KV, G, D] + + for i in range(1, len(cache_k)): + ki = cache_k[i] # [B, S, KV, D] + vi = cache_v[i] # [B, S, KV, D] + + ki_expanded = ki.unsqueeze(-2) # [B, S, KV, 1, D] + + # Dot Product: [B, S, KV, G] + score_i = (qi_reshaped * ki_expanded).sum(-1) / math.sqrt(self.head_dim) + + # Flatten back to [B, S, H_local] + step_lse = score_i.view(bsz, current_q_len, -1) + + vi_expanded = vi.unsqueeze(-2) + step_out = vi_expanded.expand(q_shape_expanded).reshape(acc_out.shape) + + # Online Softmax Update + new_lse = torch.logaddexp(acc_lse, step_lse) + + acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze( + -1 + ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1) + + acc_lse = new_lse + + attn_output = acc_out.to(query_states.dtype) + + # ============================================================= + # 4. Ulysses Gather & Output Projection + # ============================================================= + attn_output = SeqAllToAll4D.apply( + self.ulysses_pg, + attn_output, + self.gather_idx, # Scatter idx: 1 (Seq) + self.scatter_idx, # Gather idx: 2 (Heads) + self.use_sync, + ) + + attn_output = attn_output.reshape( + bsz, local_q_len, self.head_dim * self.num_heads + ) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [ + F.linear(x, gate_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + up_proj = torch.cat( + [ + F.linear(x, up_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + @torch.compile(dynamic=True) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config, attention_backend: str = "sdpa"): + super().__init__() + self.hidden_size = config.hidden_size + + if attention_backend == "sdpa": + self.self_attn = LlamaAttention(config=config) + elif attention_backend == "flex_attention": + print_with_rank("Using flex attention on draft model training!") + self.self_attn = LlamaFlexAttention(config=config) + elif attention_backend == "fa": + self.self_attn = LlamaFlashAttention(config=config) + elif attention_backend == "usp": + self.self_attn = LlamaUSPFlashAttention(config=config) + else: + raise ValueError(f"Unknown attention backend {attention_backend}") + + self.attention_backend = attention_backend + self.mlp = LlamaMLP(config) + # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # if self.index!=0: + + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + input_emb: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: List[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.hidden_norm(hidden_states) + input_emb = self.input_layernorm(input_emb) + + hidden_states = torch.cat((input_emb, hidden_states), dim=-1) + # Self Attention + hidden_states = self.self_attn( + cache_hidden=cache_hidden, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs = (hidden_states, return_hidden) + return hidden_states + + +class LlamaForCausalLMEagle3(Eagle3DraftModel): + + config_class = LlamaConfig + + def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: + super().__init__(config) + self.config = config + self.quant_config = quant_config + + self.vocab_size = config.vocab_size + self.draft_vocab_size = config.draft_vocab_size + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, config.pad_token_id + ) + self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) + + if hasattr(config, "target_hidden_size"): + self.fc = torch.nn.Linear( + config.target_hidden_size * 3, config.hidden_size, bias=False + ) + else: + self.fc = torch.nn.Linear( + config.hidden_size * 3, config.hidden_size, bias=False + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear( + config.hidden_size, config.draft_vocab_size, bias=False + ) + + # create vocab buffers + t2d = torch.ones(self.vocab_size, dtype=torch.bool) + d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ttt_length: int = 1, + ): + """ + Arguments: + hidden_states (`torch.FloatTensor`): input to the layer, cat low, mid high hidden_states of shape `(batch, seq_len, hidden_states * 3)` + input_ids (`torch.LongTensor`): input ids of shape `(batch, seq_len)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)` + """ + if ttt_length == 1: + print_with_rank("using ttt_length 1, no need to cache hidden states") + cache_hidden = None + else: + print_with_rank(f"using ttt_length {ttt_length}, caching hidden states") + cache_hidden = [[], []] + + batch_size, seq_length, _ = hidden_states.size() + + # make position ids + device = hidden_states.device + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + # make attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, 0 + ) + + # fc + hidden_states = self.fc(hidden_states) + hidden_states = self.midlayer( + input_emb=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + output_attentions=False, + use_cache=False, + ) + + # norm + hidden_states = self.norm(hidden_states) + + return hidden_states + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + # eagle 3 requires hidden states from 3 layers + assert hidden_states.size(-1) == self.config.hidden_size * 3 + return self.fc(hidden_states) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + norm_hidden_states = self.norm(hidden_states) + return self.lm_head(norm_hidden_states) + + def backbone( + self, + input_embeds: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Cache] = None, + use_cache: bool = True, + ) -> torch.Tensor: + return self.midlayer( + input_emb=input_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=False, + use_cache=False, + ) diff --git a/progress/SpecForge/specforge/modeling/target/__init__.py b/progress/SpecForge/specforge/modeling/target/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f70b3b740d055ae72dfebacc5b9f7434f3eed0e --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/__init__.py @@ -0,0 +1,17 @@ +from .eagle3_target_model import ( + CustomEagle3TargetModel, + Eagle3TargetModel, + HFEagle3TargetModel, + SGLangEagle3TargetModel, + get_eagle3_target_model, +) +from .target_head import TargetHead + +__all__ = [ + "Eagle3TargetModel", + "SGLangEagle3TargetModel", + "HFEagle3TargetModel", + "CustomEagle3TargetModel", + "get_eagle3_target_model", + "TargetHead", +] diff --git a/progress/SpecForge/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c196f0cfc950b32f54f5172bc3766da22fecb93c Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06508afe587e90390ce26182901057057f6b2a3f Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d105399dec43e0a1f118238b5509f9a3768df984 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__init__.py b/progress/SpecForge/specforge/modeling/target/custom_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5465d15a8e84c788c43e5df709c33f4efb0bd43d --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/__init__.py @@ -0,0 +1,17 @@ +from .gpt_oss import GptOssForCausalLM +from .llama import LlamaForCausalLM +from .llama4 import Llama4ForCausalLM +from .phi3 import Phi3ForCausalLM +from .qwen2 import Qwen2ForCausalLM +from .qwen3 import Qwen3ForCausalLM +from .qwen3_moe import Qwen3MoeForCausalLM + +__all__ = [ + "GptOssForCausalLM", + "LlamaForCausalLM", + "Llama4ForCausalLM", + "Phi3ForCausalLM", + "Qwen2ForCausalLM", + "Qwen3ForCausalLM", + "Qwen3MoeForCausalLM", +] diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..581a99907094e43a77487c8af01c46faa710c46d Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb0edbea283ad0f19dc0c016df744a29dcc6efbc Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c69e4477481f934762f3a013498a71676ded2b72 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1cbe4e301752d78c8b985c03b847837423fa3fb Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef063084ae6bcd90dde231352cc76523989abd9a Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cfdbaf413b4ca7fdde5fce7421d255e1eb6d19e Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e020d74cd9ed85b0403783d4ffcba5bc2cb5879 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a916b4b79771f1f89e1f9f476d968248c564b4ae Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/gpt_oss.py b/progress/SpecForge/specforge/modeling/target/custom_backend/gpt_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b4a79723f48dd2b3e7db077e030f88288a3145 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/gpt_oss.py @@ -0,0 +1,879 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations.hub_kernels import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group, shard_tensor +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + + +class GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + # apply tp + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.expert_dim_per_shard = self.expert_dim // self.tp_size + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, self.hidden_size, 2 * self.expert_dim_per_shard + ) + ) + self.gate_up_proj_bias = nn.Parameter( + torch.empty(self.num_experts, 2 * self.expert_dim_per_shard) + ) + self.down_proj = nn.Parameter( + torch.empty((self.num_experts, self.expert_dim_per_shard, self.hidden_size)) + ) + self.down_proj_bias = nn.Parameter( + torch.empty(self.num_experts, self.hidden_size) + ) + + self.alpha = 1.702 + self.limit = 7.0 + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "down_proj" in state_dict: + # columnwise splitting + value = state_dict["down_proj"] + state_dict["down_proj"] = shard_tensor(value, self.tp_group, 1) + + if "down_proj_bias" in state_dict: + value = state_dict["down_proj_bias"] + if dist.get_rank(self.tp_group) != 0: + value.zero_() + + if "gate_up_proj_bias" in state_dict: + value = state_dict["gate_up_proj_bias"] + state_dict["gate_up_proj_bias"] = shard_tensor(value, self.tp_group, 1) + + if "gate_up_proj" in state_dict: + value = state_dict["gate_up_proj"] + gate, up = value[..., ::2], value[..., 1::2] + gate = shard_tensor(gate, self.tp_group, 2) + up = shard_tensor(up, self.tp_group, 2) + new_value = torch.zeros_like(self.gate_up_proj, device=value.device) + new_value[..., ::2] = gate + new_value[..., 1::2] = up + state_dict["gate_up_proj"] = new_value + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """ + When training is is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. + + Args: + hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape( + -1, self.hidden_size + ) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + if self.training: + next_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=num_experts + ) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence lenght to get which experts + # are hit this time around + expert_hitted = torch.greater( + expert_mask.sum(dim=(-1, -2)), 0 + ).nonzero() + for expert_idx in expert_hitted[:]: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = ( + current_state @ self.gate_up_proj[expert_idx] + + self.gate_up_proj_bias[expert_idx] + ) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = ( + gated_output @ self.down_proj[expert_idx] + + self.down_proj_bias[expert_idx] + ) + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_( + 0, token_idx, weighted_output.to(hidden_states.dtype) + ) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = ( + torch.bmm(hidden_states, self.gate_up_proj) + + self.gate_up_proj_bias[..., None, :] + ) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view( + num_experts, batch_size, -1, self.hidden_size + ) + next_states = ( + next_states + * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[ + ..., None + ] + ) + dist.all_reduce(next_states, op=dist.ReduceOp.SUM, group=self.tp_group) + + next_states = next_states.sum(dim=0) + return next_states + + +class GptOssTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.empty(self.num_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear( + hidden_states, self.weight, self.bias + ) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk( + router_logits, self.top_k, dim=-1 + ) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax( + router_top_value, dim=1, dtype=router_top_value.dtype + ) + router_scores = torch.zeros_like(router_logits).scatter_( + 1, router_indices, router_top_value + ) + return router_scores, router_indices + + +@use_kernel_forward_from_hub("MegaBlocksMoeMLP") +class GptOssMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.router = GptOssTopKRouter(config) + self.experts = GptOssExperts(config) + + def forward(self, hidden_states): + router_scores, router_indices = self.router( + hidden_states + ) # (num_experts, seq_len) + routed_out = self.experts( + hidden_states, router_indices=router_indices, routing_weights=router_scores + ) + return routed_out, router_scores + + +class GptOssRotaryEmbedding(nn.Module): + def __init__(self, config: GptOssConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = freqs + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(x.dtype), sin.to(x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + first_half, second_half = torch.chunk(x, 2, dim=-1) + first_ = first_half * cos - second_half * sin + second_ = second_half * cos + first_half * sin + return torch.cat((first_, second_), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = _apply_rotary_emb(q, cos, sin) + k_embed = _apply_rotary_emb(k, cos, sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + query.shape[0], -1, query.shape[-2], -1 + ) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class GptOssAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + # self.q_proj = nn.Linear( + # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + # ) + # self.k_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.v_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.o_proj = nn.Linear( + # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + # ) + # self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + # self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.num_attention_heads_per_shard = config.num_attention_heads // self.tp_size + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + self.sinks = nn.Parameter(torch.empty(self.num_attention_heads_per_shard)) + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "sinks" in state_dict: + value = state_dict["sinks"] + state_dict["sinks"] = shard_tensor(value, self.tp_group, 0) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class GptOssDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GptOssRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class GptOssPreTrainedModel(PreTrainedModel): + config: GptOssConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GptOssDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = False + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_flash_attention = False + _supports_flex_attention = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GptOssRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, GptOssExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, GptOssAttention): + module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) + + +@auto_docstring +class GptOssModel(GptOssPreTrainedModel): + _no_split_modules = ["GptOssDecoderLayer"] + + def __init__(self, config: GptOssConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + GptOssDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GptOssRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = GptOssModel(config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"] diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/llama.py b/progress/SpecForge/specforge/modeling/target/custom_backend/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..04a3f6c9bd40b684e5d287ddf4477ea50cfa68c8 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/llama.py @@ -0,0 +1,460 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, logging +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class TensorParallelLlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=config.mlp_bias + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class TensorParallelLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # self.q_proj = nn.Linear( + # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + # ) + # self.k_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.v_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.o_proj = nn.Linear( + # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + # ) + + # distributed linear layers + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class TensorParallelLlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = TensorParallelLlamaAttention( + config=config, layer_idx=layer_idx + ) + + self.mlp = TensorParallelLlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaPreTrainedModel(PreTrainedModel): + config: LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TensorParallelLlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + + +class LlamaModel(LlamaPreTrainedModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + TensorParallelLlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + + # distributed the lm head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + ) + + +__all__ = [ + "LlamaForCausalLM", + "LlamaModel", +] diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/llama4.py b/progress/SpecForge/specforge/modeling/target/custom_backend/llama4.py new file mode 100644 index 0000000000000000000000000000000000000000..22f807daed1f6a1b1535745afb95a4feee7e3d0b --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/llama4.py @@ -0,0 +1,613 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations.hub_kernels import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama4.configuration_llama4 import ( + Llama4Config, + Llama4TextConfig, +) +from transformers.models.llama4.modeling_llama4 import ( + Llama4Router, + Llama4TextL2Norm, + Llama4TextRMSNorm, + Llama4TextRotaryEmbedding, + Llama4VisionModel, + apply_rotary_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs + +# [MODIFIED] Import from transformers library +from specforge.distributed import get_tp_group, shard_tensor +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Llama4TextExperts(nn.Module): + def __init__(self, config: Llama4TextConfig): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.expert_dim_per_shard = self.expert_dim // self.tp_size + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, self.hidden_size, 2 * self.expert_dim_per_shard + ) + ) + self.down_proj = nn.Parameter( + torch.empty((self.num_experts, self.expert_dim_per_shard, self.hidden_size)) + ) + self.act_fn = ACT2FN[config.hidden_act] + + # deal with weight loading and sharding + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "down_proj" in state_dict: + value = state_dict["down_proj"] + state_dict["down_proj"] = shard_tensor(value, self.tp_group, 1) + + if "gate_up_proj" in state_dict: + value = state_dict["gate_up_proj"] + gate, up = value.chunk(2, dim=-1) + gate = shard_tensor(gate, self.tp_group, -1) + up = shard_tensor(up, self.tp_group, -1) + value = torch.cat((gate, up), dim=-1) + state_dict["gate_up_proj"] = value + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. + - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + hidden_states = hidden_states.view( + self.gate_up_proj.shape[0], -1, self.hidden_size + ) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + dist.all_reduce(next_states, op=dist.ReduceOp.SUM, group=self.tp_group) + next_states = next_states.view(-1, self.hidden_size) + return next_states + + +class Llama4TextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + + if intermediate_size is None: + intermediate_size = config.intermediate_size + + self.config = config + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + intermediate_size, config.hidden_size, bias=False + ) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) + out = self.down_proj(down_proj) + dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.tp_group) + return out + + +class Llama4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Llama4TextConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.use_rope = config.no_rope_layers[layer_idx] + + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://huggingface.co/papers/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log1p( + torch.floor((cache_position.float() + 1.0) / self.floor_scale) + ) + * self.attn_scale + + 1.0 + ) + attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand( + (*input_shape, 1, 1) + ) # batch size > 1 + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("Llama4TextMoe") +class Llama4TextMoe(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.experts = Llama4TextExperts(config) + self.router = Llama4Router(config) + self.shared_expert = Llama4TextMLP(config) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + routed_in = hidden_states.repeat(router_scores.shape[1], 1) + routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1) + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + out.add_( + routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum( + dim=0 + ) + ) + return out, router_logits + + +class Llama4TextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Llama4TextAttention(config, layer_idx) + self.is_moe_layer = layer_idx in config.moe_layers + if self.is_moe_layer: # the 128E model interleaves dense / sparse + self.feed_forward = Llama4TextMoe(config) + else: + self.feed_forward = Llama4TextMLP( + config, intermediate_size=config.intermediate_size_mlp + ) + + self.input_layernorm = Llama4TextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Llama4TextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attention_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + attention_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + if self.is_moe_layer: + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states.view(residual.shape) + return hidden_states + + +@auto_docstring +class Llama4PreTrainedModel(PreTrainedModel): + config: Llama4Config + supports_gradient_checkpointing = True + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Llama4TextRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Llama4TextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + elif isinstance(module, Llama4VisionModel): + module.class_embedding.data.normal_(std=module.scale) + module.positional_embedding_vlm.data.normal_(std=module.scale) + + +@auto_docstring +class Llama4TextModel(Llama4PreTrainedModel): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "model" + config: Llama4TextConfig + _can_record_outputs = {} + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Llama4TextDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Llama4TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens( + input_ids.to(self.embed_tokens.weight.device) + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "chunked_attention": create_chunked_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + freq_cis = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=freq_cis, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "language_model" + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + config: Llama4TextConfig + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.model = Llama4TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Llama4ForCausalLM + + >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/phi3.py b/progress/SpecForge/specforge/modeling/target/custom_backend/phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..2515701f90f8c58cd164fc3e345549877212f379 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/phi3.py @@ -0,0 +1,495 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from transformers import Phi3Config +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.phi3.modeling_phi3 import ( + Phi3RMSNorm, + Phi3RotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # Add TP support + self.tp_group = get_tp_group() + + self.gate_up_proj = ColumnParallelLinear( + config.hidden_size, + 2 * config.intermediate_size, + bias=False, + layout_type="gate_up", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, config.hidden_size, bias=False + ) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + down_proj = self.down_proj(up_states) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support + self.tp_group = get_tp_group() + tp_size = dist.get_world_size(self.tp_group) + + # Adjust head counts for TP + self.num_attention_heads_per_rank = config.num_attention_heads // tp_size + self.num_key_value_heads_per_rank = config.num_key_value_heads // tp_size + + # ColumnParallel splits the full QKV output across ranks + op_size = config.num_attention_heads * self.head_dim + 2 * ( + config.num_key_value_heads * self.head_dim + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=False + ) + self.qkv_proj = ColumnParallelLinear( + config.hidden_size, op_size, bias=False, layout_type="merged_qkv" + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_attention_heads_per_rank * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[ + ..., + query_pos : query_pos + self.num_key_value_heads_per_rank * self.head_dim, + ] + value_states = qkv[ + ..., query_pos + self.num_key_value_heads_per_rank * self.head_dim : + ] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Phi3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout( + hidden_states + ) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout( + hidden_states + ) # main diff with Llama + return hidden_states + + +@auto_docstring +class Phi3PreTrainedModel(PreTrainedModel): + config: Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + _version = "0.0.5" + + +@auto_docstring +class Phi3Model(Phi3PreTrainedModel): + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Phi3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +@auto_docstring +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/qwen2.py b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ea42f95b4ca6b28bc17584b616b909703f3293 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen2.py @@ -0,0 +1,829 @@ +# coding=utf-8 +# Copyright 2025 The Qwen2 and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) + +# [MODIFIED] Import from distributed library +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # distributed linear layers + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # distributed linear layers + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=True, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=True, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=True, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=False, + ) + + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError( + "The `past_key_values` should be either a `Cache` object or `None`." + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + + # distributed the lm head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + layers_to_output_hidden_states = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + layers_to_output_hidden_states=layers_to_output_hidden_states, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to( + logits.device, torch.int32 + ) + token_indices = torch.arange( + input_ids.shape[-1], device=logits.device, dtype=torch.int32 + ) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), last_non_pad_token + ] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + pooled_logits=pooled_logits, + config=self.config, + ) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function( + start_logits, end_logits, start_positions, end_positions, **kwargs + ) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3.py b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0df91f03a3fd74be205cc685ad864f73fd35e8 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3.py @@ -0,0 +1,606 @@ +# coding=utf-8 +# Copyright 2025 Qwen Team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers import Qwen3Config +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3RMSNorm, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import auto_docstring, can_return_tuple, logging + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Add TP support + self.tp_group = get_tp_group() + + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.total_num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support + self.tp_group = get_tp_group() + + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + # Sliding window logic is kept as is, assuming it's handled in config.layer_types + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm( + self.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + return outputs + + +class Qwen3RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen3PreTrainedModel(PreTrainedModel): + config_class = Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen3Model(Qwen3PreTrainedModel): + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3_moe.py b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..61f1880f6d3f92ab112388bb7ec991e8535f8600 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/custom_backend/qwen3_moe.py @@ -0,0 +1,889 @@ +# coding=utf-8 +# Copyright 2025 Qwen Team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen3MoeConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import auto_docstring, can_return_tuple, logging + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen3MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support and head calculations + self.tp_group = get_tp_group() + self.tp_size = ( + dist.get_world_size(self.tp_group) if self.tp_group is not None else 1 + ) + self.tp_rank = dist.get_rank(self.tp_group) if self.tp_group is not None else 0 + + # Calculate head distribution for TP + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_key_value_heads + self.num_heads = ( + self.total_num_heads // self.tp_size + ) # this is the number heads per rank + + # Handle KV head replication when tp_size > total_num_kv_heads + if self.tp_size > self.total_num_kv_heads: + # In replication mode, each rank gets 1 KV head (replicated across groups) + self.num_kv_heads = 1 + self.num_kv_head_replicas = self.tp_size // self.total_num_kv_heads + self.num_key_value_groups = ( + self.num_heads // self.num_kv_heads + ) # this is size for expanding kv for gqa + self.kv_head_replicas = True + else: + self.num_kv_heads = self.total_num_kv_heads + self.num_kv_head_replicas = 1 + self.num_key_value_groups = config.num_attention_heads // self.num_kv_heads + self.kv_head_replicas = False + + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + kv_head_replicas=self.kv_head_replicas, + kv_head_idx=self.tp_rank // self.num_kv_head_replicas, + total_num_kv_heads=config.num_key_value_heads, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + kv_head_replicas=self.kv_head_replicas, + kv_head_idx=self.tp_rank // self.num_kv_head_replicas, + total_num_kv_heads=config.num_key_value_heads, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.q_norm = Qwen3MoeRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3MoeRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + self.sliding_window = getattr(config, "sliding_window", None) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm( + self.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + intermediate_size + if intermediate_size is not None + else config.intermediate_size + ) + + # Add TP support + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [ + Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(self.num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3MoeAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoeSparseMoeBlock(config) + else: + self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class Qwen3MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen3MoePreTrainedModel(PreTrainedModel): + config_class = Qwen3MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3MoeRMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen3MoeModel(Qwen3MoePreTrainedModel): + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3MoeDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3MoeModel(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None and aux_loss != 0: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/progress/SpecForge/specforge/modeling/target/dflash_target_model.py b/progress/SpecForge/specforge/modeling/target/dflash_target_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b9fe0c5f4922d79eca369dc87d761aec0a0032e5 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/dflash_target_model.py @@ -0,0 +1,354 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.scheduler import Scheduler +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather +from transformers import AutoModelForCausalLM + +from specforge.distributed import get_tp_device_mesh, get_tp_group +from specforge.utils import padding + +from .sglang_backend import SGLangRunner + + +@dataclass +class DFlashTargetOutput: + hidden_states: torch.Tensor # [batch, seq_len, hidden_size] + input_ids: torch.Tensor # [batch, seq_len] + attention_mask: torch.Tensor # [batch, seq_len] + loss_mask: torch.Tensor # [batch, seq_len] + + +class DFlashTargetModel(ABC): + """ + Abstract base class for DFlash target model backend. + """ + + def __init__(self): + self.capture_layer_ids = None + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "DFlashTargetModel": + """Initialize the target model backend.""" + + @abstractmethod + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + """Generate context hidden states for DFlash training.""" + + def set_capture_layers(self, layer_ids: List[int]) -> None: + """Set which layers' hidden states to capture.""" + self.capture_layer_ids = layer_ids + + +class SGLangDFlashTargetModel(DFlashTargetModel): + def __init__(self, model_runner: SGLangRunner): + super().__init__() + self.model_runner = model_runner + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "SGLangDFlashTargetModel": + tp_size = dist.get_world_size(get_tp_group()) + server_args = ServerArgs( + model_path=pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + dtype=torch_dtype, + enable_return_hidden_states=True, # Critical for DFlash + disable_cuda_graph=True, + tp_size=tp_size, + pp_size=1, + **kwargs, + ) + + tp_rank = dist.get_rank(get_tp_group()) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + model_config = ModelConfig.from_server_args(server_args) + + model_runner = SGLangRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=torch.cuda.current_device(), + tp_rank=dist.get_rank(get_tp_group()), + tp_size=server_args.tp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=server_args.ep_size, + pp_rank=0, + pp_size=1, + server_args=server_args, + nccl_port=None, + ) + return cls(model_runner) + + def set_capture_layers(self, layer_ids: List[int]) -> None: + super().set_capture_layers(layer_ids) + # Note: We need to ensure SGLang supports custom capture layers. + # Eagle3 implementation uses `set_eagle3_layers_to_capture`. + # For DFlash, we might need to rely on `output_hidden_states=True` returning all layers + # and then filtering, OR implementing `set_custom_layers_to_capture` in SGLang patch. + # Assuming we can use the same mechanism or general mechanism if available. + # If SGLang doesn't support selective capture easily, we might get all and select later. + # But for memory efficiency, selective capture is better. + + # Checking Eagle3 implementation again: it calls `model.set_eagle3_layers_to_capture`. + # This implies SGLang model wrapper has this method patched. + # We will try to use a similar approach or assume we get full hidden states. + + # For now, let's assume we capture what's needed. + if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"): + self.model_runner.model.set_eagle3_layers_to_capture(layer_ids) + + @torch.no_grad + def _extend(self, reqs): + # Similar to Eagle3 _extend but simplified for just hidden states + cache_params = CacheInitParams( + disable=False, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + page_size=self.model_runner.server_args.page_size, + ) + tree_cache = RadixCache(cache_params) + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=tree_cache, + model_config=self.model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + + if require_mlp_sync(self.model_runner.server_args): + Scheduler.prepare_mlp_sync_batch_raw( + batch, + dp_size=self.model_runner.server_args.dp_size, + attn_tp_size=1, + tp_group=self.model_runner.tp_group, + get_idle_batch=None, + disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph, + spec_algorithm=SpeculativeAlgorithm.NONE, + speculative_num_draft_tokens=None, + require_mlp_tp_gather=require_mlp_tp_gather( + self.model_runner.server_args + ), + disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule, + offload_tags=set(), + ) + + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + output, _ = self.model_runner.forward(forward_batch) + + # Eagle3 output has aux_hidden_states. + # We need to check what SGLang returns. Typically it returns 'hidden_states' or 'aux_hidden_states'. + # Assuming it aligns with Eagle3 patch. + + input_lens = [len(req.origin_input_ids) for req in reqs] + + # Split per request + if ( + hasattr(output, "aux_hidden_states") + and output.aux_hidden_states is not None + ): + hidden_states_list = torch.split( + output.aux_hidden_states, input_lens, dim=0 + ) + elif hasattr(output, "hidden_states") and output.hidden_states is not None: + hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0) + else: + raise ValueError("SGLang output does not contain hidden states.") + + self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + + return hidden_states_list + + @torch.no_grad() + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + sampling_params = SamplingParams(temperature=0, max_new_tokens=1) + reqs, data_cache = [], [] + + if isinstance(input_ids, torch.Tensor): + input_ids_list = torch.split(input_ids, 1, dim=0) + attn_mask_list = torch.split(attention_mask, 1, dim=0) + loss_mask_list = torch.split(loss_mask, 1, dim=0) + + for idx, (curr_ids, curr_attn, curr_loss) in enumerate( + zip(input_ids_list, attn_mask_list, loss_mask_list) + ): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=curr_ids.view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + data_cache.append((curr_ids, curr_attn, curr_loss)) + reqs.append(req) + + hidden_states_list = self._extend(reqs) + + # Stack back to batch + hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0) + input_ids = torch.cat([d[0] for d in data_cache], dim=0) + attention_mask = torch.cat([d[1] for d in data_cache], dim=0) + loss_mask = torch.cat([d[2] for d in data_cache], dim=0) + + # Padding might be needed if batching varied lengths (but usually fixed length training) + hidden_states = padding(hidden_states, left=False) + input_ids = padding(input_ids, left=False) + + return DFlashTargetOutput( + hidden_states=hidden_states, + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + + +class HFDFlashTargetModel(DFlashTargetModel): + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = True, + **kwargs, + ) -> "HFDFlashTargetModel": + tp_size = get_tp_group().size() + + if tp_size > 1: + device_kwargs = { + "tp_plan": "auto", + "tp_size": tp_size, + "device_mesh": get_tp_device_mesh(), + } + else: + device_kwargs = { + "device_map": device, + } + + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + output_hidden_states=True, + trust_remote_code=trust_remote_code, + attn_implementation="flash_attention_2", + **device_kwargs, + **kwargs, + ).eval() + + return cls(target_model) + + @torch.no_grad() + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Extract selected layers + # outputs.hidden_states is a tuple of (L+1) tensors + # Indices in self.capture_layer_ids correspond to 0-based index of transformer layers. + # outputs.hidden_states[0] is embedding output (usually). + # Typically hidden_states[i+1] is output of layer i. + + offset = 1 + selected = [] + if self.capture_layer_ids is not None: + for idx in self.capture_layer_ids: + selected.append(outputs.hidden_states[idx + offset]) + hidden_states = torch.cat(selected, dim=-1) + else: + # Fallback if no layers specified (maybe return last?) + hidden_states = outputs.hidden_states[-1] + + return DFlashTargetOutput( + hidden_states=hidden_states, + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + + +def get_dflash_target_model( + pretrained_model_name_or_path: str, + backend: str = "sglang", + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> DFlashTargetModel: + if backend == "sglang": + return SGLangDFlashTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "hf": + return HFDFlashTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + else: + raise ValueError(f"Invalid backend: {backend}") diff --git a/progress/SpecForge/specforge/modeling/target/eagle3_target_model.py b/progress/SpecForge/specforge/modeling/target/eagle3_target_model.py new file mode 100644 index 0000000000000000000000000000000000000000..23554414868bb0e36cd6c1e6728bdf671813b03c --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/eagle3_target_model.py @@ -0,0 +1,857 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import sglang.srt.managers.mm_utils as mm_utils +import torch +import torch.distributed as dist +import torch.nn as nn +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + init_mm_embedding_cache, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, + Req, + ScheduleBatch, +) +from sglang.srt.managers.scheduler import Scheduler +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather +from transformers import AutoModelForCausalLM + +from specforge.distributed import get_tp_device_mesh, get_tp_group +from specforge.utils import padding + +from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module +from .sglang_backend.utils import LogitsProcessorForEAGLE3 + + +@dataclass +class Eagle3TargetOutput: + hidden_states: torch.Tensor + target: torch.Tensor + loss_mask: torch.Tensor + input_ids: torch.Tensor + attention_mask: torch.Tensor + last_hidden_states: Optional[torch.Tensor] = None + + +class Eagle3TargetModel(ABC): + """ + This offers a layer of abstraction for the target model backend. The user can choose different backends to suit their needs: + 1. SGLang backend: for the mainstream model support with the fastest inference speed + 2. HuggingFace backend: for models that are not supported by SGLang but can be loaded by HuggingFace. + 3. Custom backend: for models with customized architecture and inference plan. + """ + + def __init__(self): + self.aux_hidden_states_layers = None + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "Eagle3TargetModel": + """ + Initialize the target model backend from a pretrained model path. + """ + + @abstractmethod + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + """ + Generate the eagle3 data from the target model. + """ + + def set_aux_hidden_states_layers( + self, aux_hidden_states_layers: Optional[List[int]] = None + ) -> None: + """ + Set the layers to capture the aux hidden states from the target model outputs. + """ + if aux_hidden_states_layers is None: + if hasattr(self.model.config, "num_hidden_layers"): + num_layers = self.model.config.num_hidden_layers + else: + raise ValueError( + f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers" + ) + aux_hidden_states_layers = [ + 1, + num_layers // 2 - 1, + num_layers - 4, + ] + self.aux_hidden_states_layers = aux_hidden_states_layers + assert ( + len(self.aux_hidden_states_layers) == 3 + ), "aux_hidden_states_layers is expected to be 3 layers for EAGLE3" + + +class HFEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "HFEagle3TargetModel": + """ + Initialize the HuggingFace target model backend from a pretrained model path. + """ + tp_size = get_tp_group().size() + + if tp_size > 1: + device_kwargs = { + "tp_plan": "auto", + "tp_size": tp_size, + "device_mesh": get_tp_device_mesh(), + } + else: + device_kwargs = { + "device_map": device, + } + + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + **device_kwargs, + **kwargs, + ) + return cls(target_model) + + def _get_transformer_layers(self): + """ + Helper to find the module list containing the transformer layers. + Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.) + """ + if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): + return self.model.model.layers + elif hasattr(self.model, "layers"): + return self.model.layers + elif hasattr(self.model, "transformer") and hasattr( + self.model.transformer, "h" + ): + return self.model.transformer.h + else: + raise ValueError( + "Could not locate transformer layers in the model architecture to register hooks." + ) + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + """ + Optimized HF backend: + Instead of returning all hidden states (memory heavy), we use forward hooks + to capture only the specific layers required by Eagle3. + """ + captured_states = {} + handles = [] + + def get_hook(layer_idx): + def hook(module, input, output): + # HF outputs for layers are usually tuples (hidden_states, present_key_value, ...) + # We only need the hidden_states (first element) + if isinstance(output, tuple): + hidden = output[0] + else: + hidden = output + captured_states[layer_idx] = hidden + + return hook + + # Locate the transformer layers ModuleList + layers = self._get_transformer_layers() + + target_indices = self.aux_hidden_states_layers + + # Register hooks + for idx in target_indices: + # Ensure index is within bounds + if 0 <= idx < len(layers): + handles.append(layers[idx].register_forward_hook(get_hook(idx))) + else: + raise ValueError( + f"Layer index {idx} out of bounds for model with {len(layers)} layers." + ) + + try: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + output_attentions=False, + output_router_logits=False, + use_cache=False, + ) + target = outputs.logits + finally: + # Always remove hooks to prevent memory leaks or side effects on subsequent calls + for handle in handles: + handle.remove() + + # Verify we captured everything + if len(captured_states) != 3: + raise RuntimeError( + f"Expected to capture 3 layers, but captured {len(captured_states)}" + ) + + # Extract in the correct order + hidden_states0 = captured_states[target_indices[0]] + hidden_states1 = captured_states[target_indices[1]] + hidden_states2 = captured_states[target_indices[2]] + + hidden_states = torch.cat( + (hidden_states0, hidden_states1, hidden_states2), dim=-1 + ) + + # apply pading + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None].to(target.device) + + return Eagle3TargetOutput( + hidden_states=hidden_states, + target=target, + loss_mask=loss_mask, + input_ids=input_ids, + attention_mask=attention_mask, + ) + + +class SGLangEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model_runner: SGLangRunner, hf_config=None): + super().__init__() + self.model_runner = model_runner + self.hf_config = hf_config + + # VLM-specific attributes (initialized from hf_config if available) + self._init_vlm_attributes() + + def _init_vlm_attributes(self): + """Initialize VLM-specific attributes from hf_config for models like Qwen2.5-VL""" + if self.hf_config is None: + self.is_vlm = False + return + + # Check if this is a VLM model by looking for vision_config + self.is_vlm = hasattr(self.hf_config, "vision_config") + + if not self.is_vlm: + return + + init_mm_embedding_cache(1024 * 1024 * 512) + # Model type (e.g., "qwen2_5_vl", "qwen2_vl") + self.model_type = getattr(self.hf_config, "model_type", None) + + # Vision config attributes + vision_config = self.hf_config.vision_config + self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2) + self.tokens_per_second = getattr(vision_config, "tokens_per_second", None) + + # Special token IDs from hf_config + self.image_token_id = getattr(self.hf_config, "image_token_id", None) + self.video_token_id = getattr(self.hf_config, "video_token_id", None) + self.vision_start_token_id = getattr( + self.hf_config, "vision_start_token_id", None + ) + self.vision_end_token_id = getattr(self.hf_config, "vision_end_token_id", None) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "SGLangEagle3TargetModel": + tp_size = dist.get_world_size(get_tp_group()) + server_args = ServerArgs( + model_path=pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + dtype=torch_dtype, + enable_return_hidden_states=True, + disable_cuda_graph=True, # we use piecewise cuda graph for prefill instead + tp_size=tp_size, + pp_size=1, + **kwargs, + ) + + tp_rank = dist.get_rank(get_tp_group()) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + model_config = ModelConfig.from_server_args(server_args) + model_runner = SGLangRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=torch.cuda.current_device(), + tp_rank=dist.get_rank(get_tp_group()), + tp_size=server_args.tp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=server_args.ep_size, + pp_rank=0, + pp_size=1, + server_args=server_args, + nccl_port=None, + ) + wrap_eagle3_logits_processors_in_module( + model_runner.model, return_full_logits=False + ) + + # Get hf_config from model_config for VLM attributes + hf_config = getattr(model_config, "hf_config", None) + + return cls(model_runner, hf_config=hf_config) + + def set_aux_hidden_states_layers( + self, aux_hidden_states_layers: Optional[List[int]] = None + ) -> None: + self.model_runner.model.set_eagle3_layers_to_capture(aux_hidden_states_layers) + + @torch.no_grad + def _extend( + self, + reqs, + capture_aux_hidden_states: bool = True, + return_last_hidden_states: bool = False, + return_logits: bool = False, + ): + # set the logits processor for the model runner + for name, module in self.model_runner.model.named_modules(): + if isinstance(module, LogitsProcessorForEAGLE3): + module.return_last_hidden_states = return_last_hidden_states + module.return_logits = return_logits + + cache_params = CacheInitParams( + disable=False, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + page_size=self.model_runner.server_args.page_size, + ) + tree_cache = RadixCache(cache_params) + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=tree_cache, + model_config=self.model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + self._maybe_prepare_mlp_sync_batch(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL + eagle3_output, _ = self.model_runner.forward(forward_batch) + + aux_hidden_states_list = None + input_lens = [len(req.origin_input_ids) for req in reqs] + + if return_logits: + logits = torch.split(eagle3_output.logits, input_lens, dim=0) + else: + logits = [None] * len(reqs) + + if capture_aux_hidden_states: + aux_hidden_states_list = torch.split( + eagle3_output.aux_hidden_states, input_lens, dim=0 + ) + else: + aux_hidden_states_list = [None] * len(reqs) + + if return_last_hidden_states: + last_hidden_states = torch.split( + eagle3_output.last_hidden_states, input_lens, dim=0 + ) + else: + last_hidden_states = [None] * len(reqs) + + # TODO: can we not clear? + self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + return logits, aux_hidden_states_list, last_hidden_states + + def _maybe_prepare_mlp_sync_batch(self, batch: ScheduleBatch): + if require_mlp_sync(self.model_runner.server_args): + Scheduler.prepare_mlp_sync_batch_raw( + batch, + dp_size=self.model_runner.server_args.dp_size, + attn_tp_size=1, + tp_group=self.model_runner.tp_group, + get_idle_batch=None, + disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph, + spec_algorithm=SpeculativeAlgorithm.NONE, + speculative_num_draft_tokens=None, + require_mlp_tp_gather=require_mlp_tp_gather( + self.model_runner.server_args + ), + disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule, + offload_tags=set(), + ) + + def extend( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + return_last_hidden_states: bool = False, + return_logits: bool = True, + ): + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs, data_cache = [], [] + + if isinstance(input_ids, torch.Tensor): + input_ids = torch.split(input_ids, 1, dim=0) + attention_mask = torch.split(attention_mask, 1, dim=0) + loss_mask = torch.split(loss_mask, 1, dim=0) + + for idx, (input_id_, attention_mask_, loss_mask_) in enumerate( + zip( + input_ids, + attention_mask, + loss_mask, + ) + ): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=input_id_.view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + data_cache.append([input_id_, attention_mask_, loss_mask_]) + reqs.append(req) + + logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + reqs, + capture_aux_hidden_states=True, + return_last_hidden_states=return_last_hidden_states, + return_logits=return_logits, + ) + + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + + def get_rope_index( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Get M-RoPE position indices for VLM models like Qwen2.5-VL. + + This is a wrapper around MRotaryEmbedding.get_rope_index that uses + the VLM-specific attributes initialized from hf_config. + + Args: + input_ids: (batch_size, seq_len) input token IDs + image_grid_thw: (num_images, 3) image grid dimensions (t, h, w) + video_grid_thw: (num_videos, 3) video grid dimensions (t, h, w) + second_per_grid_ts: Optional temporal information for videos + attention_mask: (batch_size, seq_len) attention mask + + Returns: + position_ids: (3, batch_size, seq_len) M-RoPE position IDs + rope_deltas: Optional position deltas for incremental decoding + """ + if not self.is_vlm: + raise ValueError("get_rope_index is only available for VLM models") + + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + position_ids, rope_deltas = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + tokens_per_second=self.tokens_per_second, + ) + + return position_ids, rope_deltas + + def extend_vlm( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + return_last_hidden_states: bool = False, + return_logits: bool = True, + pixel_values: Optional[List[torch.Tensor]] = None, + image_grid_thw: Optional[List[torch.Tensor]] = None, + ): + """ + Args: + input_ids: (batch_size, seq_len) or List of (1, seq_len) tensors + attention_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + loss_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + pixel_values: List of pixel_values tensors, one per sample in batch + image_grid_thw: List of image_grid_thw tensors, one per sample in batch + """ + mm_utils.embedding_cache.clear() + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs, data_cache = [], [] + + # Split tensors if needed + if isinstance(input_ids, torch.Tensor): + batch_size = input_ids.shape[0] + input_ids = torch.split(input_ids, 1, dim=0) + attention_mask = torch.split(attention_mask, 1, dim=0) + loss_mask = torch.split(loss_mask, 1, dim=0) + else: + batch_size = len(input_ids) + # Process image_grid_thw - convert to list if needed + if image_grid_thw is None: + image_grid_thw = [None] * batch_size + elif not isinstance(image_grid_thw, (list, tuple)): + image_grid_thw = [image_grid_thw] + + # pixel_values is a single 2D tensor (total_patches, patch_dim) for Qwen2.5-VL + # We need to track offset and slice it based on image_grid_thw for each sample + pixel_values_offset = 0 # Track current offset in pixel_values + + for idx, (input_id_, attention_mask_, loss_mask_, image_grid_thw_) in enumerate( + zip( + input_ids, + attention_mask, + loss_mask, + image_grid_thw, + ) + ): + # Compute num_patches for this sample from image_grid_thw_ + # image_grid_thw_: (num_images, 3) where each row is (t, h, w) + if image_grid_thw_ is not None: + # Ensure image_grid_thw_ is 2D: (num_images, 3) + if image_grid_thw_.dim() == 1: + image_grid_thw_ = image_grid_thw_.unsqueeze(0) # (3,) -> (1, 3) + elif image_grid_thw_.dim() == 0: + raise ValueError( + f"image_grid_thw_ is 0-dim tensor, expected at least 1D. Value: {image_grid_thw_}" + ) + + # Calculate num_patches for this sample: sum(t * h * w) for all images + num_patches = ( + ( + image_grid_thw_[:, 0] + * image_grid_thw_[:, 1] + * image_grid_thw_[:, 2] + ) + .sum() + .item() + ) + num_patches = int(num_patches) + + # Slice pixel_values for this sample + pixel_value_ = pixel_values[ + pixel_values_offset : pixel_values_offset + num_patches + ] + pixel_values_offset += num_patches + else: + pixel_value_ = None + num_patches = 0 + + # Compute mrope positions for VLM models (e.g., Qwen2.5-VL) + input_id_flat = input_id_.view(-1) + + # Count image tokens + num_img_tokens = (input_id_flat == self.image_token_id).sum().item() + # print(f"[extend_vlm] num_img_tokens in input_ids: {num_img_tokens}") + + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_id_flat.unsqueeze(0), + image_grid_thw=( + image_grid_thw_.cpu() if image_grid_thw_ is not None else None + ), + tokens_per_second=self.tokens_per_second, + ) + + offset = BaseMultimodalProcessor.get_mm_items_offset( + input_id_flat, self.image_token_id + ) + mm_item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_value_, # torch.Tensor: (num_patches, patch_dim) + pad_value=self.image_token_id, # Required for placeholder tensor creation + offsets=offset, # List of (start, end) tuples + ) + mm_item.set("image_grid_thw", image_grid_thw_.cpu()) + mm_item.set_pad_value() + mm_inputs = MultimodalInputs( + mm_items=[mm_item], + im_token_id=self.image_token_id, + im_start_id=self.vision_start_token_id, + im_end_id=self.vision_end_token_id, + mrope_positions=( + mrope_positions.squeeze(1) if mrope_positions is not None else None + ), + mrope_position_delta=mrope_position_delta, + ) + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + input_id_list = pattern.pad_input_tokens( + input_id_.view(-1).tolist(), mm_inputs + ) + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=input_id_list, + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + req.multimodal_inputs = mm_inputs + data_cache.append([input_id_, attention_mask_, loss_mask_]) + reqs.append(req) + + logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + reqs, + capture_aux_hidden_states=True, + return_last_hidden_states=return_last_hidden_states, + return_logits=return_logits, + ) + + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, + ) -> Eagle3TargetOutput: + """ + return: + data_for_draft: List[Dict[str, torch.Tensor]] of draft_batch_size, draft_micro_batch_size = 1 + - input_ids: (1, seq_len) + - attention_mask: (1, seq_len) + - loss_mask: (1, seq_len) + - target: (1, seq_len, vocab_size) or (1, seq_len, hidden_size) + - hidden_states: (1, seq_len, hidden_size) + - pixel_values: (patch_len, patch_width) + - image_grid_thw (batch_size, 3) + """ + if is_vlm: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend_vlm( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + ) + else: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + ) + ) + aux_hidden_states_out = [] + target_out = [] + loss_mask_out = [] + input_ids_out = [] + last_hidden_states_out = [] + + for idx, (data, logits, aux_hidden_states, last_hidden_states) in enumerate( + zip( + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + ) + ): + aux_hidden_states_out.append(aux_hidden_states.unsqueeze(0)) + loss_mask_out.append(data[2]) + input_ids_out.append(data[0]) + + # when generating hidden states for offline training, we don't compute logits and only keep the last_hidden_states + # when training online, we don't keep the last_hidden_states and only keep the logits + if logits is not None: + target_out.append(logits.unsqueeze(0)) + else: + target_out.append(None) + + if last_hidden_states is not None: + last_hidden_states_out.append(last_hidden_states.unsqueeze(0)) + else: + last_hidden_states_out.append(None) + + aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0) + + loss_mask_out = torch.cat(loss_mask_out, dim=0) + input_ids_out = torch.cat(input_ids_out, dim=0) + + if target_out[0] is not None: + target_out = torch.cat(target_out, dim=0) + else: + target_out = None + + if last_hidden_states_out[0] is not None: + last_hidden_states_out = torch.cat(last_hidden_states_out, dim=0) + else: + last_hidden_states_out = None + + target_out = padding(target_out, left=False) + input_ids_out = padding(input_ids_out, left=False) + loss_mask_out = loss_mask_out[..., None] + + return Eagle3TargetOutput( + hidden_states=aux_hidden_states_out, + target=target_out, + loss_mask=loss_mask_out, + input_ids=input_ids_out, + attention_mask=attention_mask, + last_hidden_states=last_hidden_states_out, + ) + + +class CustomEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "CustomEagle3TargetModel": + from specforge.modeling.auto import AutoDistributedTargetModel + + target_model = AutoDistributedTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + device=device, + **kwargs, + ) + return cls(target_model) + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + layers_to_output_hidden_states=self.aux_hidden_states_layers, + use_cache=False, + ) + + # For custom backends, the model implementation is responsible for only + # returning the requested layers in `outputs.hidden_states`. + hidden_states = torch.cat(outputs.hidden_states, dim=-1) + + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None].to(target.device) + + return Eagle3TargetOutput( + hidden_states=hidden_states, + target=target, + loss_mask=loss_mask, + input_ids=input_ids, + attention_mask=attention_mask, + ) + + +def get_eagle3_target_model( + pretrained_model_name_or_path: str, + backend: str = "sglang", + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Eagle3TargetModel: + if backend == "sglang": + return SGLangEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "hf": + return HFEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "custom": + return CustomEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + else: + raise ValueError(f"Invalid backend: {backend}") diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/__init__.py b/progress/SpecForge/specforge/modeling/target/sglang_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e02ab7b3950bf2405141a61a89c071f42a9a2a7 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/sglang_backend/__init__.py @@ -0,0 +1,4 @@ +from .model_runner import SGLangRunner +from .utils import wrap_eagle3_logits_processors_in_module + +__all__ = ["SGLangRunner", "wrap_eagle3_logits_processors_in_module"] diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eb5a20dbf12a50ea67cbd15ee6894a2426fc3ba Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb69217f9a3613396211a2e66c260246fe4c2b04 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d819cf02a202b8827388b9679d0fc56058733283 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76099b13214986bc36d6ab625f5b20b90f2390e2 Binary files /dev/null and b/progress/SpecForge/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc differ diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/model_runner.py b/progress/SpecForge/specforge/modeling/target/sglang_backend/model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..7b86a30747dd3c26007483dd7d436ee86d33dd1b --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/sglang_backend/model_runner.py @@ -0,0 +1,159 @@ +import logging +import os + +import torch +from sglang.srt.distributed import ( + get_pp_group, + get_tp_group, + get_world_group, + set_custom_all_reduce, + set_mscclpp_all_reduce, + set_torch_symm_mem_all_reduce, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + initialize_dp_attention, +) +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.utils import ( + cpu_has_amx_support, + get_available_gpu_memory, + get_bool_env_var, + is_hip, + is_npu, + monkey_patch_p2p_access_check, +) + +from .patch import ( + init_distributed_environment, + initialize_dp_attention, + initialize_model_parallel, +) + +_is_hip = is_hip() +_is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() + +# Use a small KV cache pool size for tests in CI +SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) + +# Detect stragger ranks in model loading +UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 + +logger = logging.getLogger(__name__) + + +class SGLangRunner(ModelRunner): + + def init_torch_distributed(self): + logger.info("Init torch distributed begin.") + + try: + torch.get_device_module(self.device).set_device(self.gpu_id) + except Exception: + logger.warning( + f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}" + ) + raise + + if self.device == "cuda": + if self.server_args.elastic_ep_backend == "mooncake": + backend = "mooncake" + if self.server_args.mooncake_ib_device: + mooncake_ib_device = self.server_args.mooncake_ib_device.split(",") + try: + from mooncake import ep as mooncake_ep + + mooncake_ep.set_device_filter(mooncake_ib_device) + except: + pass # A warning will be raised in `init_distributed_environment` + else: + backend = "nccl" + elif self.device == "xpu": + backend = "xccl" + elif self.device == "hpu": + backend = "hccl" + elif self.device == "cpu": + backend = "gloo" + elif self.device == "npu": + backend = "hccl" + + before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + if not self.server_args.enable_p2p_check: + monkey_patch_p2p_access_check() + + if self.server_args.dist_init_addr: + dist_init_method = f"tcp://{self.server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" + set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_mscclpp_all_reduce(self.server_args.enable_mscclpp) + set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem) + + if not self.is_draft_worker: + if self.device == "cpu": + if _is_cpu_amx_available: + # Bind OpenMP threads to CPU cores + torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) + + # Set local size to hint SGLang to use shared memory based AllReduce + os.environ["LOCAL_SIZE"] = str(self.tp_size) + torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank) + + @torch.library.register_fake("sgl_kernel::shm_allgather") + def _(data, dim): + return torch.cat([data] * self.tp_size, dim=dim) + + else: + logger.warning( + "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available" + ) + + # Only initialize the distributed environment on the target model worker. + init_distributed_environment( + backend=backend, + world_size=self.tp_size * self.pp_size, + rank=self.tp_size * self.pp_rank + self.tp_rank, + local_rank=self.gpu_id, + ) + initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + expert_model_parallel_size=self.moe_ep_size, + duplicate_tp_group=self.server_args.enable_pdmux, + torch_compile=self.server_args.enable_piecewise_cuda_graph, + ) + initialize_dp_attention( + server_args=self.server_args, + model_config=self.model_config, + ) + + min_per_gpu_memory = get_available_gpu_memory( + self.device, + self.gpu_id, + distributed=get_world_group().world_size > 1, + cpu_group=get_world_group().cpu_group, + ) + self.tp_group = get_tp_group() + self.pp_group = get_pp_group() + self.attention_tp_group = get_attention_tp_group() + + # Check memory for tensor parallelism + local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) + if self.tp_size > 1 and not self.is_draft_worker: + if min_per_gpu_memory < local_gpu_memory * 0.9: + if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): + logger.warning( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + else: + raise ValueError( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + + logger.info( + f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB" + ) + return min_per_gpu_memory diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/patch.py b/progress/SpecForge/specforge/modeling/target/sglang_backend/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..b48ed611f1d33ee8d198ad7e3cbdc0c58ed920dc --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/sglang_backend/patch.py @@ -0,0 +1,294 @@ +import logging +from typing import Optional + +import sglang.srt.distributed.parallel_state as parallel_state +import torch +import torch.distributed as dist +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import init_model_parallel_group +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + compute_dp_attention_local_info, + compute_dp_attention_world_info, +) +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_bool_env_var + +from specforge.distributed import get_tp_group as get_specforge_tp_group + +logger = logging.getLogger(__name__) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d backend=%s", + world_size, + rank, + backend, + ) + assert ( + torch.distributed.is_initialized() + ), "distributed environment should be initialized first" + + tp_group = get_specforge_tp_group() + world_size = dist.get_world_size() + tp_size = dist.get_world_size(tp_group) + num_tp_groups = world_size // tp_size + tp_ranks = [] + for i in range(num_tp_groups): + tp_ranks.append(list(range(i * tp_size, (i + 1) * tp_size))) + + parallel_state._WORLD = GroupCoordinator( + group_ranks=tp_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_pymscclpp=False, + use_custom_allreduce=False, + use_torch_symm_mem_all_reduce=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + use_npu_communicator=False, + group_name="world", + ) + # we destroy the newly created world group and replace it + # with the existing tp group from specforge to save CUDA memory + group_to_destroy = parallel_state._WORLD.device_group + parallel_state._WORLD.device_group = tp_group + dist.destroy_process_group(group_to_destroy) + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, + duplicate_tp_group: bool = False, + torch_compile: Optional[bool] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = parallel_state._WORLD.world_size + backend = backend or dist.get_backend(parallel_state._WORLD.device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = ( + dist.get_world_size() // tensor_model_parallel_size + ) + assert ( + parallel_state._TP is None + ), "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + parallel_state._TP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="tp", + pynccl_use_current_stream=duplicate_tp_group, + torch_compile=torch_compile, + ) + + if duplicate_tp_group: + assert ( + parallel_state._PDMUX_PREFILL_TP_GROUP is None + ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" + assert ( + parallel_state._PDMUX_PREFILL_TP_GROUP is None + ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" + parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="pdmux_prefill_tp", + pynccl_use_current_stream=True, + torch_compile=torch_compile, + ) + parallel_state._TP.pynccl_comm.disabled = False + parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + + moe_ep_size = expert_model_parallel_size + + moe_tp_size = tensor_model_parallel_size // moe_ep_size + assert ( + parallel_state._MOE_EP is None + ), "expert model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_tp_size): + st = i * tensor_model_parallel_size + j + en = (i + 1) * tensor_model_parallel_size + j + ranks = list(range(st, en, moe_tp_size)) + group_ranks.append(ranks) + + parallel_state._MOE_EP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="moe_ep", + ) + + assert ( + parallel_state._MOE_TP is None + ), "moe tensor model parallel group is already initialized" + if moe_ep_size == 1: + parallel_state._MOE_TP = parallel_state._TP + else: + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_ep_size): + st = i * tensor_model_parallel_size + j * moe_tp_size + en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size + ranks = list(range(st, en)) + group_ranks.append(ranks) + parallel_state._MOE_TP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="moe_tp", + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = ( + dist.get_world_size() // pipeline_model_parallel_size + ) + assert ( + parallel_state._PP is None + ), "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list( + range(i, dist.get_world_size(), num_pipeline_model_parallel_groups) + ) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + parallel_state._PP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + ) + + +def initialize_dp_attention( + server_args: ServerArgs, + model_config: ModelConfig, +): + import sglang.srt.layers.dp_attention as dp_attention + from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP + + enable_dp_attention = server_args.enable_dp_attention + tp_size = server_args.tp_size + dp_size = server_args.dp_size + moe_dense_tp_size = server_args.moe_dense_tp_size + pp_size = server_args.pp_size + + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention + + ( + dp_attention._ATTN_TP_RANK, + dp_attention._ATTN_TP_SIZE, + dp_attention._ATTN_DP_RANK, + ) = compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size) + _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info( + enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size + ) + + if enable_dp_attention: + dp_attention._ATTN_DP_SIZE = dp_size + if moe_dense_tp_size is None: + dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE + else: + dp_attention._LOCAL_ATTN_DP_SIZE = max( + 1, dp_size // (tp_size // moe_dense_tp_size) + ) + else: + dp_attention._ATTN_DP_SIZE = 1 + dp_attention._LOCAL_ATTN_DP_SIZE = 1 + + tp_group = parallel_state.get_tp_group() + num_model_parallel_groups = dist.get_world_size() // (pp_size * tp_size) + mp_size = pp_size * tp_size + group_ranks = [] + + for i in range(num_model_parallel_groups): + ranks = [ + list(range(head, head + dp_attention._ATTN_TP_SIZE)) + for head in range( + mp_size * i, mp_size * (i + 1), dp_attention._ATTN_TP_SIZE + ) + ] + group_ranks.extend(ranks) + + dp_attention._ATTN_TP_GROUP = GroupCoordinator( + group_ranks, + tp_group.local_rank, + torch.distributed.get_backend(tp_group.device_group), + use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP, + use_pymscclpp=False, + use_custom_allreduce=False, + use_torch_symm_mem_all_reduce=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + use_npu_communicator=False, + group_name="attention_tp", + ) + # print(f"{parallel_state._ATTN_TP_GROUP=}") + + _DpGatheredBufferWrapper.set_metadata( + hidden_size=model_config.hidden_size, + dtype=model_config.dtype, + device=torch.device(server_args.device), + ) diff --git a/progress/SpecForge/specforge/modeling/target/sglang_backend/utils.py b/progress/SpecForge/specforge/modeling/target/sglang_backend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..441adbb5808c0614789ac0e3c033dd53788a0e37 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/sglang_backend/utils.py @@ -0,0 +1,165 @@ +""" +This file contains the wrapper for the SGL model. +""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +from sglang.srt.layers.logits_processor import ( + LogitsMetadata, + LogitsProcessor, + LogitsProcessorOutput, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args + + +@dataclass +class ReplacedLogitsProcessorEagle3Output: + """ + A dataclass to store the logits and aux hidden states needed for EAGLE3. + """ + + logits: torch.Tensor + aux_hidden_states: torch.Tensor + last_hidden_states: Optional[torch.Tensor] = None + + +def replaced_logits_processor_forward_for_eagle3( + self, + input_ids, + hidden_states, + lm_head, + logits_metadata: Union[LogitsMetadata, ForwardBatch], + aux_hidden_states: Optional[torch.Tensor] = None, + return_last_hidden_states: bool = False, + return_logits: bool = False, +) -> LogitsProcessorOutput: + """ + This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py. + The modification is to return the logits and aux hidden states instead of the last hidden states. + """ + + if isinstance(logits_metadata, ForwardBatch): + logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) + + # Check if multi-item scoring is enabled via server args (only for prefill-only requests) + multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter + if multi_item_delimiter is not None and logits_metadata.is_prefill_only: + return self.compute_logprobs_for_multi_item_scoring( + input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter + ) + + # Get the last hidden states and last logits for the next token prediction + if ( + logits_metadata.forward_mode.is_decode_or_idle() + or logits_metadata.forward_mode.is_target_verify() + or logits_metadata.forward_mode.is_draft_extend_v2() + ): + pruned_states = hidden_states + if aux_hidden_states is not None: + aux_pruned_states = [hidden for hidden in aux_hidden_states] + sample_indices = None + input_logprob_indices = None + else: + raise RuntimeError( + f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}" + ) + + if return_last_hidden_states: + last_hidden_states = pruned_states + else: + last_hidden_states = None + + if return_logits: + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + else: + logits = None + + # get the aux hidden states + hidden_states_to_store: Optional[torch.Tensor] = None + if logits_metadata.capture_hidden_mode.need_capture(): + if logits_metadata.capture_hidden_mode.is_full(): + if aux_hidden_states is not None: + aux_hidden_states = torch.cat(aux_hidden_states, dim=-1) + hidden_states_to_store = aux_hidden_states + else: + hidden_states_to_store = hidden_states + elif logits_metadata.capture_hidden_mode.is_last(): + # Get the last token hidden states. If sample_indices is None, + # pruned states only contain the last tokens already. + if aux_hidden_states is not None: + aux_pruned_states = torch.cat(aux_pruned_states, dim=-1) + hidden_states_to_store = ( + aux_pruned_states[sample_indices] + if sample_indices is not None + else aux_pruned_states + ) + else: + hidden_states_to_store = ( + pruned_states[sample_indices] + if sample_indices is not None + else pruned_states + ) + else: + assert False, "Should never reach" + + assert ( + not logits_metadata.extend_return_logprob + ), "extend_return_logprob is not supported" + # Decode mode or extend mode without return_logprob. + return ReplacedLogitsProcessorEagle3Output( + logits=logits, + aux_hidden_states=hidden_states_to_store, + last_hidden_states=last_hidden_states, + ) + + +class LogitsProcessorForEAGLE3(torch.nn.Module): + def __init__( + self, + logits_processor: LogitsProcessor, + return_last_hidden_states: bool = False, + return_logits: bool = False, + ): + super().__init__() + self.logits_processor = logits_processor + self.return_last_hidden_states = return_last_hidden_states + self.return_logits = return_logits + + def forward( + self, + input_ids, + hidden_states, + lm_head, + logits_metadata, + aux_hidden_states: Optional[torch.Tensor] = None, + ) -> LogitsProcessorOutput: + logits_metadata.forward_mode = ForwardMode.DECODE + ret = replaced_logits_processor_forward_for_eagle3( + self.logits_processor, + input_ids, + hidden_states, + lm_head, + logits_metadata, + aux_hidden_states, + self.return_last_hidden_states, + self.return_logits, + ) + return ret + + +def wrap_eagle3_logits_processors_in_module( + module: nn.Module, return_full_logits: bool = False +): + """ + This function will wrap the SGLang's original logits processor with the modified one for EAGLE3. + """ + for name, submodule in module.named_modules(): + if isinstance(submodule, LogitsProcessor): + wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) + setattr(module, name, wrapped) + print(f"wrapped {name} with LogitsProcessorForEAGLE3") diff --git a/progress/SpecForge/specforge/modeling/target/target_head.py b/progress/SpecForge/specforge/modeling/target/target_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7231117cefa1903e733ead6305642318422e64b9 --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/target_head.py @@ -0,0 +1,92 @@ +import glob +import json +import os +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers import AutoConfig + +from specforge.utils import padding + + +class TargetHead(nn.Module): + def __init__(self, model_path, trust_remote_code: bool = False): + super().__init__() + self.config = AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + self.fc = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + @classmethod + def from_pretrained( + cls, + model_path, + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + ) -> "TargetHead": + target_head = cls(model_path, trust_remote_code=trust_remote_code) + target_head.load_weights( + model_path=model_path, + lm_head_key=lm_head_key, + cache_dir=cache_dir, + ) + target_head.freeze_weights() + target_head = target_head.eval().cuda().to(torch.bfloat16) + return target_head + + @torch.no_grad() + def load_weights( + self, + model_path, + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + ): + if os.path.exists(model_path): + self.model_path = model_path + else: + self.model_path = snapshot_download(repo_id=model_path) + + # model_path is a local directory + # check if there is file ending with index.json + glob_path = os.path.join(self.model_path, "*.index.json") + index_json_path = glob.glob(glob_path) + + if len(index_json_path) == 0: + raise FileNotFoundError(f"No index.json file found in {self.model_path}") + if len(index_json_path) > 1: + raise FileNotFoundError( + f"Multiple index.json files found in {self.model_path}" + ) + index_json_path = index_json_path[0] + + with open(index_json_path, "r") as f: + index_json = json.load(f) + ckpt_file = index_json["weight_map"][lm_head_key] + + if ckpt_file.endswith(".safetensors"): + with safe_open( + os.path.join(self.model_path, ckpt_file), framework="pt" + ) as f: + lm_head = f.get_tensor(lm_head_key) + else: + state_dict = torch.load(os.path.join(self.model_path, ckpt_file)) + lm_head = state_dict[lm_head_key] + self.fc.weight.copy_(lm_head) + + def freeze_weights(self): + for param in self.fc.parameters(): + param.requires_grad = False + + def forward(self, hidden_states): + return self.fc(hidden_states) + + def preprocess(self, input_ids, target, loss_mask): + # apply pading + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None] + return input_ids, target, loss_mask diff --git a/progress/SpecForge/specforge/modeling/target/target_utils.py b/progress/SpecForge/specforge/modeling/target/target_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9377753fce2f1391fc14abf2f47a8634c57f0eee --- /dev/null +++ b/progress/SpecForge/specforge/modeling/target/target_utils.py @@ -0,0 +1,134 @@ +import glob +import json +import os +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers import AutoConfig + + +class TargetEmbeddingsAndHead(nn.Module): + """ + Efficiently loads only the embedding layer and lm_head from a pretrained model. + Avoids loading the full model into memory. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + @classmethod + def from_pretrained( + cls, + model_path: str, + embed_key: str = "model.embed_tokens.weight", + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + trust_remote_code: bool = False, + ) -> "TargetEmbeddingsAndHead": + + # 1. Load Config + config = AutoConfig.from_pretrained( + model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code + ) + instance = cls(config) + + # 2. Resolve Model Path (Handle Hub) + local_model_path = model_path + if not os.path.exists(local_model_path): + try: + local_model_path = snapshot_download( + repo_id=model_path, cache_dir=cache_dir + ) + except: + pass # Maybe it's a local path that looks like a repo ID but doesn't exist? + + # 3. Load Weights Efficiently + instance._load_weights(local_model_path, embed_key, lm_head_key) + + # 4. Move to Device & Freeze + instance.to(device=device, dtype=dtype) + instance.eval() + instance.requires_grad_(False) + + return instance + + def _load_weights(self, model_path: str, embed_key: str, lm_head_key: str): + # Locate index.json + index_files = glob.glob(os.path.join(model_path, "*.index.json")) + + weight_map = {} + if index_files: + # Sharded Checkpoint + with open(index_files[0], "r") as f: + index = json.load(f) + + # Find which file contains our keys + weight_map = index.get("weight_map", {}) + files_to_load = {} + + if embed_key in weight_map: + files_to_load[embed_key] = weight_map[embed_key] + else: + # Fallback: sometimes keys are prefixed differently? + print( + f"Warning: {embed_key} not found in weight_map. Keys available: {list(weight_map.keys())[:5]}..." + ) + + if lm_head_key in weight_map: + files_to_load[lm_head_key] = weight_map[lm_head_key] + + # Load specific files + for key, filename in files_to_load.items(): + file_path = os.path.join(model_path, filename) + self._load_key_from_file(file_path, key) + + else: + # Non-sharded Checkpoint (single file) + # Try finding .safetensors or .bin + safetensors = glob.glob(os.path.join(model_path, "*.safetensors")) + bins = glob.glob(os.path.join(model_path, "*.bin")) + + target_file = None + if safetensors: + target_file = safetensors[0] + elif bins: + target_file = bins[0] + + if target_file: + self._load_key_from_file(target_file, embed_key) + self._load_key_from_file(target_file, lm_head_key) + else: + raise FileNotFoundError(f"No checkpoint file found in {model_path}") + + def _load_key_from_file(self, file_path: str, key: str): + tensor = None + if file_path.endswith(".safetensors"): + with safe_open(file_path, framework="pt") as f: + if key in f.keys(): + tensor = f.get_tensor(key) + else: + # torch.load loads full dict, less efficient but works + state_dict = torch.load(file_path, map_location="cpu") + if key in state_dict: + tensor = state_dict[key] + del state_dict # Free immediately + + if tensor is not None: + if key.endswith("embed_tokens.weight"): + self.embed_tokens.weight.data.copy_(tensor) + print(f"Loaded embedding weights from {file_path}") + elif key.endswith("lm_head.weight"): + self.lm_head.weight.data.copy_(tensor) + print(f"Loaded lm_head weights from {file_path}") + else: + print(f"Warning: Key {key} not found in {file_path}") diff --git a/progress/SpecForge/specforge/modeling/utils.py b/progress/SpecForge/specforge/modeling/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cdd45642c0761e7178c1990e2e1bab6420b15ea --- /dev/null +++ b/progress/SpecForge/specforge/modeling/utils.py @@ -0,0 +1,11 @@ +import torch + + +@torch.no_grad() +def padding(tensor, left=True): + zeropadding = torch.zeros_like(tensor[:, -1:]) + if left: + tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1) + else: + tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1) + return tensor diff --git a/progress/SpecForge/specforge/optimizer.py b/progress/SpecForge/specforge/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ed15a6616fd4828f82e8e1220e60126a915104f3 --- /dev/null +++ b/progress/SpecForge/specforge/optimizer.py @@ -0,0 +1,146 @@ +import json + +import torch + +from specforge.lr_scheduler import CosineAnnealingWarmupLR +from specforge.utils import print_on_rank0 + + +class BF16Optimizer: + def __init__( + self, + model, + lr, + weight_decay=0.0, + max_grad_norm=0.5, + total_steps=800_000, + warmup_ratio=0.015, + use_fp32_params=True, + optimizer_type="adamw", + optimizer_config=None, + ): + self.model = model + self.model_params = [p for p in model.parameters() if p.requires_grad] + self.max_grad_norm = max_grad_norm + self.use_fp32_params = use_fp32_params + self.optimizer_type = optimizer_type + + if use_fp32_params: + self.fp32_params = [ + p.detach().clone().to(torch.float32) for p in self.model_params + ] + for mp in self.fp32_params: + mp.requires_grad = True + self.optim_params = self.fp32_params + else: + self.fp32_params = None + self.optim_params = self.model_params + + self.optimizer = self._create_optimizer(lr, weight_decay, optimizer_config) + self.scheduler = CosineAnnealingWarmupLR( + self.optimizer, + total_steps=total_steps, + warmup_steps=int(warmup_ratio * total_steps), + ) + + def _create_optimizer(self, lr, weight_decay, optimizer_config): + if self.optimizer_type == "adamw": + return torch.optim.AdamW( + self.optim_params, lr=lr, weight_decay=weight_decay + ) + elif self.optimizer_type == "adamw_8bit": + import bitsandbytes as bnb + + return bnb.optim.AdamW8bit( + self.optim_params, lr=lr, weight_decay=weight_decay + ) + elif self.optimizer_type == "apollo": + from apollo_torch import APOLLOAdamW + + assert optimizer_config is not None, ( + "optimizer_config path is required when optimizer_type='apollo'" + ) + with open(optimizer_config, "r") as f: + apollo_cfg = json.load(f) + param_groups = self._build_apollo_param_groups( + apollo_cfg, lr, weight_decay + ) + return APOLLOAdamW(param_groups, lr=lr, weight_decay=weight_decay) + else: + raise ValueError( + f"Unknown optimizer_type: {self.optimizer_type}. " + f"Supported types: adamw, adamw_8bit, apollo" + ) + + def _build_apollo_param_groups(self, apollo_cfg, lr, weight_decay): + """Build param groups for APOLLO optimizer. + + Splits parameters into two groups: + - non_lowrank_params: 1D params (bias, layernorm) - standard Adam update + - lowrank_params: nD params (weight matrices) - low-rank projected update + """ + lowrank_params = [] + non_lowrank_params = [] + for p in self.optim_params: + if p.ndim >= 2: + lowrank_params.append(p) + else: + non_lowrank_params.append(p) + + param_groups = [ + {"params": non_lowrank_params, "lr": lr, "weight_decay": weight_decay}, + { + "params": lowrank_params, + "lr": lr, + "weight_decay": weight_decay, + "rank": apollo_cfg.get("rank", 1), + "proj": apollo_cfg.get("proj", "random"), + "scale_type": apollo_cfg.get("scale_type", "tensor"), + "scale": apollo_cfg.get("scale", 128), + "update_proj_gap": apollo_cfg.get("update_proj_gap", 200), + "proj_type": apollo_cfg.get("proj_type", "std"), + }, + ] + return param_groups + + def step(self): + if self.use_fp32_params: + with torch.no_grad(): + for p, mp in zip(self.model_params, self.fp32_params): + mp.grad = ( + p.grad.detach().to(torch.float32) + if p.grad is not None + else None + ) + torch.nn.utils.clip_grad_norm_(self.fp32_params, self.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(self.model_params, self.max_grad_norm) + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + if self.use_fp32_params: + with torch.no_grad(): + for p, mp in zip(self.model_params, self.fp32_params): + p.data.copy_(mp.data.to(p.dtype)) + p.grad = None + else: + with torch.no_grad(): + for p in self.model_params: + p.grad = None + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict["optimizer_state_dict"]) + print_on_rank0("Successfully loaded optimizer state_dict.") + self.scheduler.load_state_dict(state_dict["scheduler_state_dict"]) + print_on_rank0("Successfully loaded scheduler state_dict.") + + def state_dict(self): + return { + "optimizer_state_dict": self.optimizer.state_dict(), + "scheduler_state_dict": self.scheduler.state_dict(), + } + + def get_learning_rate(self): + return self.optimizer.param_groups[0]["lr"] diff --git a/progress/SpecForge/specforge/tracker.py b/progress/SpecForge/specforge/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7498c156b3194c7afcef625801a6d33a41364 --- /dev/null +++ b/progress/SpecForge/specforge/tracker.py @@ -0,0 +1,297 @@ +# tracker.py + +import abc +import netrc +import os +from typing import Any, Dict, Optional + +import torch.distributed as dist + +# --- Lazy Imports --- +# These libraries are imported only when their respective trackers are used. +try: + import wandb +except ImportError: + wandb = None + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + +try: + import swanlab +except ImportError: + swanlab = None + +try: + import mlflow +except ImportError: + mlflow = None + + +# --- End Lazy Imports --- + + +class Tracker(abc.ABC): + """ + Abstract Base Class for experiment trackers. + + Each tracker implementation should handle its own initialization, logging, + and cleanup. It should also provide a class method to validate + command-line arguments before initialization. + """ + + def __init__(self, args, output_dir: str): + self.args = args + self.output_dir = output_dir + self.rank = dist.get_rank() + self.is_initialized = False + + @classmethod + @abc.abstractmethod + def validate_args(cls, parser, args) -> None: + """ + Validate necessary arguments for this tracker. + This method is called during argument parsing. + It should raise an error if required arguments are missing. + """ + + @abc.abstractmethod + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None) -> None: + """ + Log metrics to the tracker. + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Close the tracker and clean up resources. + """ + + +class NoOpTracker(Tracker): + """A tracker that does nothing, for when no tracking is desired.""" + + @classmethod + def validate_args(cls, parser, args): + pass # No arguments to validate + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + self.is_initialized = True # Considered initialized to do nothing + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + pass # Do nothing + + def close(self): + pass # Do nothing + + +class WandbTracker(Tracker): + """Tracks experiments using Weights & Biases.""" + + @classmethod + def validate_args(cls, parser, args): + if wandb is None: + parser.error( + "To use --report-to wandb, you must install wandb: 'pip install wandb'" + ) + + if args.wandb_key is not None: + return + + if "WANDB_API_KEY" in os.environ: + args.wandb_key = os.environ["WANDB_API_KEY"] + return + + try: + netrc_path = os.path.expanduser("~/.netrc") + if os.path.exists(netrc_path): + netrc_file = netrc.netrc(netrc_path) + if "api.wandb.ai" in netrc_file.hosts: + _, _, password = netrc_file.authenticators("api.wandb.ai") + if password: + args.wandb_key = password + return + except (FileNotFoundError, netrc.NetrcParseError): + pass + + if args.wandb_key is None: + parser.error( + "When --report-to is 'wandb', you must provide a wandb API key via one of:\n" + " 1. --wandb-key argument\n" + " 2. WANDB_API_KEY environment variable\n" + " 3. `wandb login` command" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + wandb.login(key=args.wandb_key) + wandb.init( + project=args.wandb_project, name=args.wandb_name, config=vars(args) + ) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + wandb.log(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized and wandb.run: + wandb.finish() + self.is_initialized = False + + +class SwanlabTracker(Tracker): + """Tracks experiments using SwanLab.""" + + @classmethod + def validate_args(cls, parser, args): + if swanlab is None: + parser.error( + "To use --report-to swanlab, you must install swanlab: 'pip install swanlab'" + ) + + if args.swanlab_key is not None: + return + if "SWANLAB_API_KEY" in os.environ: + args.swanlab_key = os.environ["SWANLAB_API_KEY"] + return + # Swanlab can run in anonymous mode if no key is provided in a non-distributed env. + # However, a key is often required for distributed runs to sync correctly. + if ( + dist.is_initialized() + and dist.get_world_size() > 1 + and args.swanlab_key is None + ): + parser.error( + "In a distributed environment, when --report-to is 'swanlab', you must provide a swanlab API key via:\n" + " 1. --swanlab-key argument\n" + " 2. SWANLAB_API_KEY environment variable" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + if args.swanlab_key: + swanlab.login(api_key=args.swanlab_key) + + swanlog_dir = os.path.join(output_dir, "swanlog") + os.makedirs(swanlog_dir, exist_ok=True) + swanlab.init( + project=args.swanlab_project, + experiment_name=args.swanlab_name, + config=vars(args), + logdir=swanlog_dir, + ) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + swanlab.log(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized and swanlab.get_run() is not None: + swanlab.finish() + self.is_initialized = False + + +class TensorboardTracker(Tracker): + """Tracks experiments using TensorBoard.""" + + @classmethod + def validate_args(cls, parser, args): + if SummaryWriter is None: + parser.error( + "To use --report-to tensorboard, you must have tensorboard installed: 'pip install tensorboard'" + ) + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + log_dir = os.path.join(output_dir, "runs") + self.writer = SummaryWriter(log_dir=log_dir) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + for key, value in log_dict.items(): + if isinstance(value, (int, float)): + self.writer.add_scalar(key, value, global_step=step) + + def close(self): + if self.rank == 0 and self.is_initialized: + self.writer.close() + self.is_initialized = False + + +class MLflowTracker(Tracker): + """Tracks experiments using MLflow.""" + + @classmethod + def validate_args(cls, parser, args): + if mlflow is None: + parser.error( + "To use --report-to mlflow, you must install mlflow: 'pip install mlflow'" + ) + # Set tracking URI from environment variable if not explicitly provided + if args.mlflow_tracking_uri is None and "MLFLOW_TRACKING_URI" in os.environ: + args.mlflow_tracking_uri = os.environ["MLFLOW_TRACKING_URI"] + elif args.mlflow_tracking_uri is None: + print( + "Warning: MLflow tracking URI not set. Defaulting to local './mlruns'." + ) + + # Set experiment name from environment variable if not explicitly provided + if ( + args.mlflow_experiment_name is None + and "MLFLOW_EXPERIMENT_NAME" in os.environ + ): + args.mlflow_experiment_name = os.environ["MLFLOW_EXPERIMENT_NAME"] + + def __init__(self, args, output_dir: str): + super().__init__(args, output_dir) + if self.rank == 0: + if args.mlflow_tracking_uri: + mlflow.set_tracking_uri(args.mlflow_tracking_uri) + + # This will either use the set URI or the default + mlflow.set_experiment(args.mlflow_experiment_name) + mlflow.start_run(run_name=args.mlflow_run_name) + mlflow.log_params(vars(args)) + self.is_initialized = True + + def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): + if self.rank == 0 and self.is_initialized: + # MLflow's log_metrics takes a dictionary directly + mlflow.log_metrics(log_dict, step=step) + + def close(self): + if self.rank == 0 and self.is_initialized: + mlflow.end_run() + self.is_initialized = False + + +# --- Tracker Factory --- +TRACKER_REGISTRY = { + "wandb": WandbTracker, + "swanlab": SwanlabTracker, + "tensorboard": TensorboardTracker, + "mlflow": MLflowTracker, + "none": NoOpTracker, +} + + +def get_tracker_class(report_to: str) -> Optional[Tracker]: + """Returns the tracker class based on the name.""" + return TRACKER_REGISTRY.get(report_to) + + +def create_tracker(args, output_dir: str) -> Tracker: + """Factory function to create an experiment tracker instance.""" + tracker_class = get_tracker_class(args.report_to) + if not tracker_class: + raise ValueError(f"Unsupported report_to type: {args.report_to}") + return tracker_class(args, output_dir) diff --git a/progress/SpecForge/specforge/utils.py b/progress/SpecForge/specforge/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..59724a82463a64e8937e87da61f6a38179fe4f3b --- /dev/null +++ b/progress/SpecForge/specforge/utils.py @@ -0,0 +1,359 @@ +import json +import logging +import os +import re +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Shard, distribute_tensor +from transformers import AutoConfig, PretrainedConfig + +logger = logging.getLogger(__name__) + + +@contextmanager +def rank_0_priority(): + rank = dist.get_rank() + + if rank == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def default_torch_dtype(dtype: torch.dtype): + current_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(current_dtype) + + +@torch.no_grad() +def padding(tensor, left=True): + zeropadding = torch.zeros_like(tensor[:, -1:]) + if left: + tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1) + else: + tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1) + return tensor + + +def load_config_from_file(config_path: str): + with open(config_path, "r") as f: + config = json.load(f) + + return PretrainedConfig.from_dict(config) + + +def print_with_rank(message): + if dist.is_available() and dist.is_initialized(): + logger.info(f"rank {dist.get_rank()}: {message}") + else: + logger.info(f"non-distributed: {message}") + + +def print_args_with_dots(args): + if dist.get_rank() == 0: + args_dict = vars(args) + max_key_length = max(len(key) for key in args_dict.keys()) + total_width = 50 + + print("\n -----------【args】-----------") + for key, value in args_dict.items(): + key_str = f"{key:<{max_key_length}}" + value_str = str(value) + dot_count = total_width - len(key_str) - len(value_str) + dot_fill = "·" * dot_count + print(f"{key_str} {dot_fill} {value_str}") + + +def print_on_rank0(message): + if dist.get_rank() == 0: + logger.info(message) + + +def get_last_checkpoint(folder, prefix="epoch"): + content = os.listdir(folder) + _re_checkpoint = re.compile(r"^" + prefix + r"_(\d+)$") + checkpoints = [ + path + for path in content + if _re_checkpoint.search(path) is not None + and os.path.isdir(os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + return os.path.join( + folder, + max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])), + ) + + +def generate_draft_model_config( + target_model_path: str, template_config_path: str = None, cache_dir: str = None +): + """ + Auto-generate draft model config based on target model parameters aligned with template config + + Args: + target_model_path (str): Path to the target model + template_config_path (str, optional): Template config file path, defaults to llama3-8B-eagle3.json + cache_dir (str, optional): Cache directory + + Returns: + dict: Generated draft model config dictionary + """ + # Get target model config + target_config = AutoConfig.from_pretrained(target_model_path, cache_dir=cache_dir) + + # If no template specified, use default llama3-8B-eagle3.json + if template_config_path is None: + # Use the script execution directory as base + import sys + + script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + project_root = os.path.dirname(script_dir) # Go up one level from scripts/ + template_config_path = os.path.join( + project_root, "configs", "llama3-8B-eagle3.json" + ) + + # Read template config + with open(template_config_path, "r") as f: + draft_config = json.load(f) + + # Adjust architecture config based on target model type + if hasattr(target_config, "model_type"): + # Default to llama architecture + draft_config["model_type"] = "llama" + + # Align key parameters + param_mappings = { + "vocab_size": "vocab_size", + "hidden_size": "hidden_size", + "num_attention_heads": "num_attention_heads", + "num_key_value_heads": "num_key_value_heads", + "intermediate_size": "intermediate_size", + "max_position_embeddings": "max_position_embeddings", + "rms_norm_eps": "rms_norm_eps", + "hidden_act": "hidden_act", + "bos_token_id": "bos_token_id", + "eos_token_id": "eos_token_id", + "torch_dtype": "torch_dtype", + } + + # Copy parameters from target model to draft config + for target_param, draft_param in param_mappings.items(): + if hasattr(target_config, target_param): + value = getattr(target_config, target_param) + # Special handling for torch_dtype to make it JSON serializable + if target_param == "torch_dtype" and isinstance(value, torch.dtype): + value = str(value).replace("torch.", "") + draft_config[draft_param] = value + + # Special handling for some parameters + # Ensure num_hidden_layers is always 1 (EAGLE3 feature) + draft_config["num_hidden_layers"] = 1 + + # Keep some fixed draft model specific parameters + draft_config["tie_word_embeddings"] = False + draft_config["use_cache"] = True + + # If template doesn't have draft_vocab_size, set default + if "draft_vocab_size" not in draft_config: + draft_config["draft_vocab_size"] = 32000 # Default value + + return draft_config + + +def save_draft_model_config(config_dict: dict, output_path: str): + """ + Save draft model config to file + + Args: + config_dict (dict): Config dictionary + output_path (str): Output file path + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + json.dump(config_dict, f, indent=2, ensure_ascii=False) + + print(f"Draft model config saved to: {output_path}") + + +def create_draft_config_from_target( + target_model_path: str, + output_dir: str = None, + template_config_path: str = None, + cache_dir: str = None, +): + """ + Convenient function to create draft model config file from target model + + Args: + target_model_path (str): Target model path + output_dir (str, optional): Output directory, defaults to configs folder in current directory + template_config_path (str, optional): Template config path + cache_dir (str, optional): Cache directory + + Returns: + str: Generated config file path + """ + # Generate config + rank = dist.get_rank() + + if rank == 0: + print_with_rank( + "No draft model config provided, auto-generating from target model..." + ) + config_dict = generate_draft_model_config( + target_model_path, template_config_path, cache_dir + ) + dist.barrier() + + # Determine output path + if output_dir is None: + # Use the script execution directory as base + import sys + + script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) + project_root = os.path.dirname(script_dir) # Go up one level from scripts/ + output_dir = os.path.join(project_root, "configs") + + # Extract model name from model path + model_name = target_model_path.split("/")[-1].lower() + output_filename = f"{model_name}-eagle3-auto.json" + output_path = os.path.join(output_dir, output_filename) + + # Save config + if rank == 0: + save_draft_model_config(config_dict, output_path) + print_with_rank(f"Auto-generated draft model config saved to: {output_path}") + dist.barrier() + + return output_path + + +def get_full_optimizer_state(optimizer_state_dict: dict): + """ + Convert optimizer state dict with DTensor to full tensors for saving + + Args: + optimizer_state_dict (dict): Optimizer state dict possibly containing DTensors + Returns: + dict: Optimizer state dict with full tensors + """ + full_optimizer_state_dict = { + k: v for k, v in optimizer_state_dict.items() if k != "state" + } + if "state" in optimizer_state_dict: + full_optimizer_state_dict["state"] = { + param_id: { + state_key: ( + state_tensor.full_tensor() + if isinstance(state_tensor, torch.distributed.tensor.DTensor) + else state_tensor + ) + for state_key, state_tensor in param_state.items() + } + for param_id, param_state in optimizer_state_dict["state"].items() + } + return full_optimizer_state_dict + + +def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh): + """ + Shards the optimizer state tensors of a BF16Optimizer instance using DTensor. + + Args: + bf16_optimizer (BF16Optimizer): An instance of BF16Optimizer, which contains + the actual optimizer (e.g., torch.optim.Adam) as its `.optimizer` attribute. + """ + + optim = bf16_optimizer.optimizer + + for group in optim.param_groups: + for p in group["params"]: + if not isinstance(p, DTensor): + continue + + state = optim.state.get(p, None) + if state is None: + continue + + mesh = device_mesh + placements = (Shard(dim=0),) + + for k, v in list(state.items()): + if k == "step": + continue + + if isinstance(v, DTensor): + continue + + if not isinstance(v, torch.Tensor): + continue + + state[k] = distribute_tensor( + v.to(p.device), device_mesh=mesh, placements=placements + ) + + +def safe_conversations_generator(file_path): + """ + Generator that: + 1. Extracts the 'conversations' field. + 2. Preserves all original fields within each message. + 3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility). + """ + with open(file_path, "r", encoding="utf-8") as f: + for i, line in enumerate(f): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + raw_convs = row.get("conversations", []) + + # 1. Ensure 'conversations' is a list + if not isinstance(raw_convs, list): + # If it's None or some unexpected type, treat as empty or skip + if raw_convs is None: + raw_convs = [] + else: + # Edge case: 'conversations' is a plain string or non-iterable—skip this line + logger.warning( + f"Line {i + 1}: 'conversations' is not a list. Please check!" + ) + continue + + cleaned_convs = [] + for msg in raw_convs: + # 2. Ensure each item in the list is a dictionary + if not isinstance(msg, dict): + # Skip if an element is not a dict (e.g., malformed like ["user", "hi"]) + continue + + # 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.) + new_msg = {} + for k, v in msg.items(): + # If the value is a list or dict, serialize it to a JSON string + # This ensures Arrow treats the column as string type instead of list/struct + if isinstance(v, (list, dict)): + new_msg[k] = json.dumps(v, ensure_ascii=False) + else: + # Keep primitive types (str, int, float, bool, None) unchanged + new_msg[k] = v + + cleaned_convs.append(new_msg) + + # Yield only the processed 'conversations' + yield {"conversations": cleaned_convs} + + except Exception as e: + logger.warning(f"Skipping line {i + 1}: {e}") + continue diff --git a/progress/SpecForge/tests/__init__.py b/progress/SpecForge/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/SpecForge/tests/utils.py b/progress/SpecForge/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cc20609907eeacf62c7e76be4bee5e1caf1ea2 --- /dev/null +++ b/progress/SpecForge/tests/utils.py @@ -0,0 +1,107 @@ +import os +import socket +import subprocess +import time + +import requests +from sglang.utils import print_highlight + + +def is_port_in_use(port: int) -> bool: + """Check if a port is in use""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return False + except OSError: + return True + + +def get_available_port(): + # get a random available port + # and try to find a port that is not in use + for port in range(10000, 65535): + if not is_port_in_use(port): + return port + raise RuntimeError("No available port found") + + +def execute_shell_command( + command: str, disable_proxy: bool = False, enable_hf_mirror: bool = False +): + """ + Execute a shell command and return its process handle. + """ + command = command.replace("\\\n", " ").replace("\\", " ") + parts = command.split() + env = os.environ.copy() + + if disable_proxy: + env.pop("http_proxy", None) + env.pop("https_proxy", None) + env.pop("no_proxy", None) + env.pop("HTTP_PROXY", None) + env.pop("HTTPS_PROXY", None) + env.pop("NO_PROXY", None) + + if enable_hf_mirror: + env["HF_ENDPOINT"] = "https://hf-mirror.com" + return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT, env=env) + + +def wait_for_server( + base_url: str, timeout: int = None, disable_proxy: bool = False +) -> None: + """Wait for the server to be ready by polling the /v1/models endpoint. + + Args: + base_url: The base URL of the server + timeout: Maximum time to wait in seconds. None means wait forever. + """ + start_time = time.perf_counter() + + if disable_proxy: + http_proxy = os.environ.pop("http_proxy", None) + https_proxy = os.environ.pop("https_proxy", None) + no_proxy = os.environ.pop("no_proxy", None) + http_proxy_capitalized = os.environ.pop("HTTP_PROXY", None) + https_proxy_capitalized = os.environ.pop("HTTPS_PROXY", None) + no_proxy_capitalized = os.environ.pop("NO_PROXY", None) + + while True: + try: + response = requests.get( + f"{base_url}/v1/models", + headers={"Authorization": "Bearer None"}, + ) + if response.status_code == 200: + time.sleep(5) + print_highlight( + """\n + NOTE: Typically, the server runs in a separate terminal. + In this notebook, we run the server and notebook code together, so their outputs are combined. + To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue. + To reduce the log length, we set the log level to warning for the server, the default log level is info. + We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance. + """ + ) + break + + if timeout and time.perf_counter() - start_time > timeout: + raise TimeoutError("Server did not become ready within timeout period") + except requests.exceptions.RequestException: + time.sleep(1) + + if disable_proxy: + if http_proxy: + os.environ["http_proxy"] = http_proxy + if https_proxy: + os.environ["https_proxy"] = https_proxy + if no_proxy: + os.environ["no_proxy"] = no_proxy + if http_proxy_capitalized: + os.environ["HTTP_PROXY"] = http_proxy_capitalized + if https_proxy_capitalized: + os.environ["HTTPS_PROXY"] = https_proxy_capitalized + if no_proxy_capitalized: + os.environ["NO_PROXY"] = no_proxy_capitalized diff --git a/progress/github/SpecForge/.github/ISSUE_TEMPLATE/1-bug-report.yaml b/progress/github/SpecForge/.github/ISSUE_TEMPLATE/1-bug-report.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41fa058c4aff03bdeb9e04b06e5d2129fd2e57f1 --- /dev/null +++ b/progress/github/SpecForge/.github/ISSUE_TEMPLATE/1-bug-report.yaml @@ -0,0 +1,38 @@ +name: 🐞 Bug report +description: Create a report to help us reproduce and fix the bug +title: "[Bug] " +labels: ['Bug'] + +body: +- type: checkboxes + attributes: + label: Checklist + options: + - label: 1. I have searched related issues but cannot get the expected help. + - label: 2. The bug has not been fixed in the latest version. + - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. + - label: 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose Otherwise, it will be closed. + - label: 5. Please use English, otherwise it will be closed. +- type: textarea + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: Reproduction + description: | + What command or script did you run? Which **model** are you using? + placeholder: | + A placeholder for the command. + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please provide necessary environment information here. Otherwise the issue will be closed. + placeholder: Environment here. + validations: + required: true diff --git a/progress/github/SpecForge/.github/ISSUE_TEMPLATE/2-feature-request.yaml b/progress/github/SpecForge/.github/ISSUE_TEMPLATE/2-feature-request.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6fc81989429af5d3096a389db9adb6ec3993d60 --- /dev/null +++ b/progress/github/SpecForge/.github/ISSUE_TEMPLATE/2-feature-request.yaml @@ -0,0 +1,23 @@ +name: 🚀 Feature request +description: Suggest an idea for this project +title: "[Feature] " + +body: +- type: checkboxes + attributes: + label: Checklist + options: + - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose Otherwise, it will be closed. + - label: 2. Please use English, otherwise it will be closed. +- type: textarea + attributes: + label: Motivation + description: | + A clear and concise description of the motivation of the feature. + validations: + required: true +- type: textarea + attributes: + label: Related resources + description: | + If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful. diff --git a/progress/github/SpecForge/.github/workflows/lint.yaml b/progress/github/SpecForge/.github/workflows/lint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3cf35a6be5986ecd8e9f90cef12a75438e8401d6 --- /dev/null +++ b/progress/github/SpecForge/.github/workflows/lint.yaml @@ -0,0 +1,22 @@ +name: Lint + +on: [ pull_request ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install pre-commit hook + run: | + python -m pip install pre-commit + pre-commit install + + - name: Linting + run: pre-commit run --all-files --show-diff-on-failure diff --git a/progress/github/SpecForge/.github/workflows/publish_docs.yaml b/progress/github/SpecForge/.github/workflows/publish_docs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27f4639d2eb35474f4865f57f9031e18df722942 --- /dev/null +++ b/progress/github/SpecForge/.github/workflows/publish_docs.yaml @@ -0,0 +1,72 @@ +name: Release Documentation + +on: + push: + branches: + - main + paths: + - "docs/**" + - "version.txt" + workflow_dispatch: + +concurrency: + group: release-docs-${{ github.ref }} + cancel-in-progress: true + +jobs: + deploy-github-pages: + runs-on: ubuntu-latest + if: github.repository == 'sgl-project/specforge' || github.repository == 'sleepcoo/SpecForge' + permissions: + contents: write + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'npm' + cache-dependency-path: docs/spec_bundle/package-lock.json + + - name: Install dependencies + run: | + sudo apt-get update && sudo apt-get install -y pandoc parallel retry + pip install -r docs/requirements.txt + + - name: Build spec bundle dashboard + run: | + # Copy logos to public directory + cp assets/logo.png docs/spec_bundle/public/logo.png + cp docs/_static/imgs/specbundle-logo.png docs/spec_bundle/public/specbundle-logo.png + cd docs/spec_bundle + npm ci + npm run build + # Clean up node_modules to prevent Sphinx from processing them + rm -rf node_modules + cd .. + + - name: Build documentation + run: | + cd docs + make compile + make html + # Copy SpecBundle to root of output directory + mkdir -p _build/html/SpecBundle + cp -r spec_bundle/dist/* _build/html/SpecBundle/ + + - name: Add .nojekyll file + run: | + touch ./docs/_build/html/.nojekyll + + - name: Deploy + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs/_build/html diff --git a/progress/github/SpecForge/.github/workflows/publish_pypi.yaml b/progress/github/SpecForge/.github/workflows/publish_pypi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b2c68f1cd16ae6ada552340aef58265c7daeafb --- /dev/null +++ b/progress/github/SpecForge/.github/workflows/publish_pypi.yaml @@ -0,0 +1,33 @@ +name: Publish to PyPI + +on: + workflow_dispatch: + +jobs: + build-n-publish: + if: github.event_name == 'workflow_dispatch' + name: Build and publish Python distributions to PyPI + runs-on: ubuntu-latest + timeout-minutes: 20 + environment: + name: pypi + url: https://pypi.org/p/specforgeee + permissions: + id-token: write + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-python@v2 + with: + python-version: '3.11' + + - run: python setup.py sdist build + + # publish to PyPI if executed on the main branch + - name: Publish package to PyPI + id: publish + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} + verbose: true diff --git a/progress/github/SpecForge/.github/workflows/test.yaml b/progress/github/SpecForge/.github/workflows/test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..328dd8a17769929b692afac84729b34a2551448f --- /dev/null +++ b/progress/github/SpecForge/.github/workflows/test.yaml @@ -0,0 +1,63 @@ +name: PR Test + +on: + pull_request: + branches: [ main ] + workflow_dispatch: + +concurrency: + group: pr-test-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + unit-test: + if: (github.repository == 'sgl-project/SpecForge' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false + runs-on: [self-hosted] + container: + image: lmsysorg/sglang:v0.5.5 # we lock to this version to avoid repeated docker pull + options: --gpus all --shm-size=2g --rm -v /dev/shm + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Restore cache + run: | + if [ -d /github/home/cache ] && [ ! -z "$(ls -A /github/home/cache/)" ]; then + cp -p -r /github/home/cache ./ + fi + + if [ -d /github/home/sf ] && [ ! -z "$(ls -A /github/home/sf/)" ]; then + cp -p -r /github/home/sf ./ + fi + + - name: Remove flashinfer # this is needed to avoid flashinfer jit compilation makes the program hang + run: | + rm -rf /github/home/.cache/flashinfer + + - name: Install dependencies + shell: bash + run: | + # if sf venv does not exist, create it + if [ ! -d sf ]; then + uv venv sf -p 3.11 + fi + source sf/bin/activate + uv pip install setuptools + MAX_JOBS=8 uv pip install -v ".[fa]" --prerelease=allow --no-build-isolation + + - name: Run test + timeout-minutes: 30 + shell: bash + run: | + source sf/bin/activate + export PYTHONPATH=$PWD + python -m unittest discover -s ./tests -p "test_*.py" -v + + - name: Save cache + run: | + cp -p -r sf /github/home/ + cp -p -r cache /github/home/ diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/__grp__triton_red_fused_mul_0.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/__grp__triton_red_fused_mul_0.json new file mode 100644 index 0000000000000000000000000000000000000000..378f80bc5197a9aecb9d5dcb4b34df27f4261c71 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/__grp__triton_red_fused_mul_0.json @@ -0,0 +1 @@ +{"child_paths": {"triton_red_fused_mul_0.source": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.source", "triton_red_fused_mul_0.ttir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttir", "triton_red_fused_mul_0.ttgir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttgir", "triton_red_fused_mul_0.llir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.llir", "triton_red_fused_mul_0.ptx": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ptx", "triton_red_fused_mul_0.cubin": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.cubin", "triton_red_fused_mul_0.json": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.json"}} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.cubin b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..b8ba986b055409b781c13175a1fefe8bd76c31c1 Binary files /dev/null and b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.cubin differ diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.json new file mode 100644 index 0000000000000000000000000000000000000000..6f603cadda6e5bdea8247cfecff2d4f91315b551 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.json @@ -0,0 +1 @@ +{"hash": "f652b1e1a00143d965686269366c11027adbc930286c9a31fdc990ed50e76db4", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 256, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_mul_0"} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.llir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.llir new file mode 100644 index 0000000000000000000000000000000000000000..158e1dbe70bb3813e7807a4b42f1847e49102612 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.llir @@ -0,0 +1,161 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16 + +; Function Attrs: nounwind +define ptx_kernel void @triton_red_fused_mul_0(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, i32 %4, i32 %5, ptr addrspace(1) readnone captures(none) %6, ptr addrspace(1) readnone captures(none) %7) local_unnamed_addr #0 !dbg !4 { + %9 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !7 + %10 = shl i32 %9, 6, !dbg !8 + %11 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !9 + %12 = and i32 %11, 252, !dbg !9 + %13 = lshr exact i32 %12, 2, !dbg !9 + %14 = or disjoint i32 %13, %10, !dbg !10 + %15 = icmp slt i32 %14, 31232, !dbg !11 + %16 = and i32 %11, 3, !dbg !12 + %17 = sdiv i32 %14, 976, !dbg !13 + %18 = shl i32 %14, 7, !dbg !14 + %19 = shl i32 %14, 12 + %20 = mul i32 %17, -3997568 + %21 = add i32 %20, %19 + %22 = zext nneg i32 %16 to i64, !dbg !15 + %23 = sext i32 %18 to i64, !dbg !15 + %invariant.gep8 = getelementptr bfloat, ptr addrspace(1) %1, i64 %23, !dbg !15 + br i1 %15, label %.split.us, label %.split + +.split.us: ; preds = %8, %.split.us + %indvars.iv5 = phi i64 [ %indvars.iv.next6, %.split.us ], [ 0, %8 ] + %24 = phi float [ %39, %.split.us ], [ 0.000000e+00, %8 ] + %25 = or disjoint i64 %indvars.iv5, %22, !dbg !16 + %26 = trunc nuw nsw i64 %25 to i32, !dbg !17 + %27 = add i32 %21, %26, !dbg !17 + %28 = sext i32 %27 to i64, !dbg !18 + %29 = getelementptr bfloat, ptr addrspace(1) %0, i64 %28, !dbg !18 + %30 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !19 + %31 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %29, i64 %30, i1 true) #4, !dbg !19 + %32 = bitcast i16 %31 to bfloat, !dbg !19 + %33 = fpext bfloat %32 to float, !dbg !20 + %gep9 = getelementptr bfloat, ptr addrspace(1) %invariant.gep8, i64 %25, !dbg !21 + %34 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !22 + %35 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep9, i64 %34, i1 true) #4, !dbg !22 + %36 = bitcast i16 %35 to bfloat, !dbg !22 + %37 = fpext bfloat %36 to float, !dbg !23 + %38 = fmul float %33, %37, !dbg !24 + %39 = fadd float %24, %38, !dbg !25 + %indvars.iv.next6 = add nuw nsw i64 %indvars.iv5, 4, !dbg !15 + %40 = icmp samesign ult i64 %indvars.iv5, 124, !dbg !15 + br i1 %40, label %.split.us, label %.split2.us, !dbg !15 + +.split: ; preds = %8, %.split + %indvars.iv = phi i64 [ %indvars.iv.next, %.split ], [ 0, %8 ] + %41 = or disjoint i64 %indvars.iv, %22, !dbg !16 + %42 = trunc nuw nsw i64 %41 to i32, !dbg !17 + %43 = add i32 %21, %42, !dbg !17 + %44 = sext i32 %43 to i64, !dbg !18 + %45 = getelementptr bfloat, ptr addrspace(1) %0, i64 %44, !dbg !18 + %46 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !19 + %47 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %45, i64 %46, i1 false) #4, !dbg !19 + %gep = getelementptr bfloat, ptr addrspace(1) %invariant.gep8, i64 %41, !dbg !21 + %48 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_first.b64 $0, 1.0;", "=l"() #4, !dbg !22 + %49 = tail call i16 asm sideeffect "mov.u16 $0, $1;\0A\09@$4 ld.global.L1::evict_first.L2::cache_hint.b16 { $0 }, [ $2 + 0 ], $3;", "=c,c,l,l,b"(i16 0, ptr addrspace(1) %gep, i64 %48, i1 false) #4, !dbg !22 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 4, !dbg !15 + %50 = icmp samesign ult i64 %indvars.iv, 124, !dbg !15 + br i1 %50, label %.split, label %.split2.us, !dbg !15 + +.split2.us: ; preds = %.split, %.split.us + %.us-phi = phi float [ %39, %.split.us ], [ 0.000000e+00, %.split ], !dbg !9 + %51 = and i32 %11, 63, !dbg !9 + %52 = or disjoint i32 %10, %51, !dbg !10 + %53 = icmp slt i32 %52, 31232, !dbg !11 + %54 = bitcast float %.us-phi to i32, !dbg !26 + %55 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %54, i32 2, i32 31), !dbg !26 + %56 = bitcast i32 %55 to float, !dbg !26 + %57 = fadd float %.us-phi, %56, !dbg !30 + %58 = bitcast float %57 to i32, !dbg !26 + %59 = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 %58, i32 1, i32 31), !dbg !26 + %60 = bitcast i32 %59 to float, !dbg !26 + %61 = fadd float %57, %60, !dbg !30 + %62 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %12, !dbg !31 + store float %61, ptr addrspace(3) %62, align 4, !dbg !31 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !31 + %63 = shl nuw nsw i32 %51, 2, !dbg !31 + %64 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %63, !dbg !31 + %65 = load float, ptr addrspace(3) %64, align 4, !dbg !31 + %66 = sext i32 %52 to i64, !dbg !32 + %67 = getelementptr float, ptr addrspace(1) %2, i64 %66, !dbg !32 + %68 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #4, !dbg !33 + %69 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %67, i64 %68, i1 %53) #4, !dbg !33 + %70 = bitcast i32 %69 to float, !dbg !33 + %71 = fmul float %70, 0x3FE62E4300000000, !dbg !34 + %72 = fmul float %71, 0x3FF7154760000000, !dbg !35 + %73 = fsub float %65, %72, !dbg !31 + %74 = getelementptr float, ptr addrspace(1) %3, i64 %66, !dbg !36 + %75 = and i32 %11, 192, !dbg !37 + %76 = icmp eq i32 %75, 0, !dbg !37 + %77 = bitcast float %73 to i32, !dbg !37 + %78 = and i1 %76, %53, !dbg !37 + tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %77, ptr addrspace(1) %74, i1 %78) #4, !dbg !37 + ret void, !dbg !38 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #2 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #3 + +attributes #0 = { nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #3 = { convergent nocallback nounwind } +attributes #4 = { nounwind } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py", directory: "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = distinct !DISubprogram(name: "triton_red_fused_mul_0", linkageName: "triton_red_fused_mul_0", scope: !1, file: !1, line: 18, type: !5, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!5 = !DISubroutineType(cc: DW_CC_normal, types: !6) +!6 = !{} +!7 = !DILocation(line: 23, column: 28, scope: !4) +!8 = !DILocation(line: 23, column: 33, scope: !4) +!9 = !DILocation(line: 24, column: 44, scope: !4) +!10 = !DILocation(line: 24, column: 23, scope: !4) +!11 = !DILocation(line: 25, column: 21, scope: !4) +!12 = !DILocation(line: 26, column: 37, scope: !4) +!13 = !DILocation(line: 29, column: 19, scope: !4) +!14 = !DILocation(line: 39, column: 45, scope: !4) +!15 = !DILocation(line: 32, column: 40, scope: !4) +!16 = !DILocation(line: 33, column: 31, scope: !4) +!17 = !DILocation(line: 38, column: 50, scope: !4) +!18 = !DILocation(line: 38, column: 34, scope: !4) +!19 = !DILocation(line: 38, column: 60, scope: !4) +!20 = !DILocation(line: 38, column: 122, scope: !4) +!21 = !DILocation(line: 39, column: 34, scope: !4) +!22 = !DILocation(line: 39, column: 50, scope: !4) +!23 = !DILocation(line: 39, column: 112, scope: !4) +!24 = !DILocation(line: 40, column: 22, scope: !4) +!25 = !DILocation(line: 42, column: 23, scope: !4) +!26 = !DILocation(line: 291, column: 36, scope: !27, inlinedAt: !29) +!27 = distinct !DILexicalBlockFile(scope: !4, file: !28, discriminator: 0) +!28 = !DIFile(filename: "standard.py", directory: "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language") +!29 = !DILocation(line: 44, column: 25, scope: !4) +!30 = !DILocation(line: 261, column: 15, scope: !27, inlinedAt: !29) +!31 = !DILocation(line: 51, column: 19, scope: !4) +!32 = !DILocation(line: 45, column: 30, scope: !4) +!33 = !DILocation(line: 45, column: 35, scope: !4) +!34 = !DILocation(line: 48, column: 18, scope: !4) +!35 = !DILocation(line: 50, column: 19, scope: !4) +!36 = !DILocation(line: 52, column: 25, scope: !4) +!37 = !DILocation(line: 52, column: 37, scope: !4) +!38 = !DILocation(line: 52, column: 4, scope: !4) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ptx b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ptx new file mode 100644 index 0000000000000000000000000000000000000000..6734bb78eb0ddeb6fad1cb3384e54451cab32848 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ptx @@ -0,0 +1,458 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_red_fused_mul_0 // -- Begin function triton_red_fused_mul_0 +.extern .shared .align 16 .b8 global_smem[]; + // @triton_red_fused_mul_0 +.visible .entry triton_red_fused_mul_0( + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_0, + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_1, + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_2, + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_3, + .param .u32 triton_red_fused_mul_0_param_4, + .param .u32 triton_red_fused_mul_0_param_5, + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_6, + .param .u64 .ptr .global .align 1 triton_red_fused_mul_0_param_7 +) +.reqntid 256 +{ + .reg .pred %p<11>; + .reg .b16 %rs<9>; + .reg .b32 %r<55>; + .reg .b64 %rd<50>; + .loc 1 18 0 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:18:0 + +// %bb.0: + ld.param.b64 %rd18, [triton_red_fused_mul_0_param_3]; + ld.param.b64 %rd17, [triton_red_fused_mul_0_param_2]; + ld.param.b64 %rd16, [triton_red_fused_mul_0_param_1]; + ld.param.b64 %rd15, [triton_red_fused_mul_0_param_0]; +$L__tmp0: + .loc 1 23 28 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:23:28 + mov.u32 %r1, %ctaid.x; + .loc 1 23 33 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:23:33 + shl.b32 %r2, %r1, 6; + .loc 1 24 44 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:24:44 + mov.u32 %r3, %tid.x; + and.b32 %r4, %r3, 252; + bfe.u32 %r5, %r3, 2, 6; + .loc 1 24 23 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:24:23 + or.b32 %r10, %r5, %r2; + .loc 1 25 21 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:25:21 + setp.lt.s32 %p1, %r10, 31232; + .loc 1 26 37 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:26:37 + and.b32 %r11, %r3, 3; + .loc 1 29 19 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:29:19 + mul.hi.s32 %r12, %r10, 1126548799; + shr.u32 %r13, %r12, 31; + shr.s32 %r14, %r12, 8; + add.s32 %r6, %r14, %r13; + .loc 1 39 45 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:39:45 + shl.b32 %r15, %r10, 7; + .loc 1 32 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:32:40 + cvt.u64.u32 %rd1, %r11; + cvt.s64.s32 %rd2, %r15; + @%p1 bra $L__BB0_3; + bra.uni $L__BB0_1; +$L__BB0_3: // %.split.us.preheader + .loc 1 0 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:0:40 + cvt.u32.u64 %r27, %rd1; + .loc 1 32 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:32:40 + shl.b64 %rd31, %rd2, 1; + shl.b64 %rd32, %rd1, 1; + add.s64 %rd33, %rd31, %rd32; + add.s64 %rd46, %rd16, %rd33; + shl.b32 %r28, %r1, 18; + shl.b32 %r29, %r5, 12; + or.b32 %r30, %r28, %r29; + or.b32 %r31, %r30, %r27; + mul.lo.s32 %r32, %r6, 3997568; + sub.s32 %r33, %r31, %r32; + cvt.u64.u32 %rd6, %r33; + mov.b32 %r54, 0f00000000; + mov.b64 %rd47, -4; +$L__BB0_4: // %.split.us + // =>This Inner Loop Header: Depth=1 + .loc 1 38 34 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:38:34 + add.s64 %rd40, %rd6, %rd47; + cvt.u32.u64 %r34, %rd40; + add.s32 %r35, %r34, 4; + mad.wide.s32 %rd35, %r35, 2, %rd15; + .loc 1 38 60 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:38:60 + // begin inline asm + mov.u64 %rd34, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd34, 1.0; + // end inline asm + mov.b16 %rs6, 0; + mov.pred %p5, -1; + // begin inline asm + mov.u16 %rs5, %rs6; + @%p5 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs5 }, [ %rd35 + 0 ], %rd34; + // end inline asm + .loc 1 38 122 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:38:122 + cvt.f32.bf16 %r36, %rs5; + .loc 1 39 50 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:39:50 + // begin inline asm + mov.u64 %rd37, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd37, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs7, %rs6; + @%p5 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs7 }, [ %rd46 + 0 ], %rd37; + // end inline asm + .loc 1 39 112 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:39:112 + cvt.f32.bf16 %r37, %rs7; + .loc 1 42 23 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:42:23 + fma.rn.f32 %r54, %r36, %r37, %r54; + .loc 1 32 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:32:40 + add.s64 %rd47, %rd47, 4; + add.s64 %rd46, %rd46, 8; + setp.lt.u64 %p7, %rd47, 124; + @%p7 bra $L__BB0_4; + bra.uni $L__BB0_5; +$L__BB0_1: // %.split.preheader + .loc 1 0 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:0:40 + cvt.u32.u64 %r16, %rd1; + .loc 1 32 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:32:40 + shl.b64 %rd20, %rd2, 1; + shl.b64 %rd21, %rd1, 1; + add.s64 %rd22, %rd20, %rd21; + add.s64 %rd48, %rd16, %rd22; + shl.b32 %r17, %r1, 18; + shl.b32 %r18, %r5, 12; + or.b32 %r19, %r17, %r18; + or.b32 %r20, %r19, %r16; + mul.lo.s32 %r21, %r6, 3997568; + sub.s32 %r22, %r20, %r21; + cvt.u64.u32 %rd4, %r22; + mov.b64 %rd49, -4; +$L__BB0_2: // %.split + // =>This Inner Loop Header: Depth=1 + .loc 1 38 34 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:38:34 + add.s64 %rd29, %rd4, %rd49; + cvt.u32.u64 %r24, %rd29; + add.s32 %r25, %r24, 4; + mad.wide.s32 %rd24, %r25, 2, %rd15; + .loc 1 38 60 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:38:60 + // begin inline asm + mov.u64 %rd23, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd23, 1.0; + // end inline asm + mov.b16 %rs2, 0; + mov.pred %p2, 0; + // begin inline asm + mov.u16 %rs1, %rs2; + @%p2 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs1 }, [ %rd24 + 0 ], %rd23; + // end inline asm + .loc 1 39 50 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:39:50 + // begin inline asm + mov.u64 %rd26, 0x0; + createpolicy.fractional.L2::evict_first.b64 %rd26, 1.0; + // end inline asm + // begin inline asm + mov.u16 %rs3, %rs2; + @%p2 ld.global.L1::evict_first.L2::cache_hint.b16 { %rs3 }, [ %rd48 + 0 ], %rd26; + // end inline asm + .loc 1 32 40 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:32:40 + add.s64 %rd49, %rd49, 4; + add.s64 %rd48, %rd48, 8; + setp.lt.u64 %p4, %rd49, 124; + mov.b32 %r54, 0f00000000; + @%p4 bra $L__BB0_2; +$L__BB0_5: // %.split2.us + .loc 1 24 44 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:24:44 + and.b32 %r40, %r3, 63; + .loc 1 24 23 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:24:23 + or.b32 %r41, %r2, %r40; + .loc 1 25 21 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:25:21 + setp.lt.s32 %p8, %r41, 31232; +$L__tmp1: + .loc 2 291 36 // standard.py:291:36 @[ cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:44:25 ] + shfl.sync.bfly.b32 %r42, %r54, 2, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:44:25 ] + add.f32 %r43, %r54, %r42; + .loc 2 291 36 // standard.py:291:36 @[ cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:44:25 ] + shfl.sync.bfly.b32 %r44, %r43, 1, 31, -1; + .loc 2 261 15 // standard.py:261:15 @[ cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:44:25 ] + add.f32 %r45, %r43, %r44; +$L__tmp2: + .loc 1 51 19 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:51:19 + mov.b32 %r46, global_smem; + add.s32 %r47, %r46, %r4; + st.shared.b32 [%r47], %r45; + bar.sync 0; + shl.b32 %r48, %r40, 2; + add.s32 %r49, %r46, %r48; + ld.shared.b32 %r50, [%r49]; + .loc 1 45 30 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:45:30 + mul.wide.s32 %rd45, %r41, 4; + add.s64 %rd42, %rd17, %rd45; + .loc 1 45 35 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:45:35 + // begin inline asm + mov.u64 %rd43, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd43, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r38, 0x0; + @%p8 ld.global.L1::evict_last.L2::cache_hint.b32 { %r38 }, [ %rd42 + 0 ], %rd43; + // end inline asm + .loc 1 48 18 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:48:18 + mul.f32 %r51, %r38, 0fBF317218; + .loc 1 51 19 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:51:19 + fma.rn.f32 %r39, %r51, 0f3FB8AA3B, %r50; + .loc 1 52 25 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:52:25 + add.s64 %rd44, %rd18, %rd45; + .loc 1 52 37 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:52:37 + and.b32 %r52, %r3, 192; + setp.eq.b32 %p10, %r52, 0; + and.pred %p9, %p10, %p8; + // begin inline asm + @%p9 st.global.b32 [ %rd44 + 0 ], { %r39 }; + // end inline asm + .loc 1 52 4 // cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py:52:4 + ret; +$L__tmp3: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py" + .file 2 "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 211 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0xcc DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 112 +.b8 115 +.b8 110 +.b8 112 +.b8 55 +.b8 112 +.b8 101 +.b8 99 +.b8 120 +.b8 102 +.b8 101 +.b8 120 +.b8 108 +.b8 53 +.b8 100 +.b8 107 +.b8 101 +.b8 116 +.b8 114 +.b8 101 +.b8 118 +.b8 108 +.b8 103 +.b8 116 +.b8 102 +.b8 117 +.b8 119 +.b8 112 +.b8 112 +.b8 119 +.b8 113 +.b8 117 +.b8 107 +.b8 115 +.b8 51 +.b8 55 +.b8 108 +.b8 50 +.b8 120 +.b8 102 +.b8 121 +.b8 97 +.b8 99 +.b8 109 +.b8 112 +.b8 52 +.b8 114 +.b8 116 +.b8 55 +.b8 97 +.b8 53 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 106 +.b8 117 +.b8 110 +.b8 113 +.b8 117 +.b8 97 +.b8 110 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 112 +.b8 115 +.b8 0 +.b8 2 // Abbrev [2] 0x8f:0x19 DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 114 +.b8 101 +.b8 100 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 109 +.b8 117 +.b8 108 +.b8 95 +.b8 48 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa8:0x2e DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 143 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbd:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 44 // DW_AT_call_line +.b8 25 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.source b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.source new file mode 100644 index 0000000000000000000000000000000000000000..d67ea6a280ff158306492e88a00e112319fbe2a0 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.source @@ -0,0 +1,230 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":18:0) +#loc48 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":285:0) +#loc50 = loc(unknown) +#loc53 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":260:0) +#loc57 = loc("in_ptr0"(#loc)) +#loc58 = loc("in_ptr1"(#loc)) +#loc59 = loc("in_ptr2"(#loc)) +#loc60 = loc("out_ptr1"(#loc)) +#loc61 = loc("xnumel"(#loc)) +#loc62 = loc("r0_numel"(#loc)) +#loc106 = loc("input"(#loc48)) +#loc107 = loc("a"(#loc53)) +#loc108 = loc("b"(#loc53)) +module { + tt.func public @triton_red_fused_mul_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 31232 : i32 loc(#loc63) + %r0_numel_1 = arith.constant 128 : i32 loc(#loc64) + %xoffset = tt.get_program_id x : i32 loc(#loc65) + %xoffset_2 = arith.constant 64 : i32 loc(#loc66) + %xoffset_3 = arith.constant 64 : i32 loc(#loc66) + %xoffset_4 = arith.muli %xoffset, %xoffset_3 : i32 loc(#loc66) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc67) + %xindex_5 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc68) + %xindex_6 = tt.splat %xoffset_4 : i32 -> tensor<64x1xi32> loc(#loc69) + %xindex_7 = arith.addi %xindex_6, %xindex_5 : tensor<64x1xi32> loc(#loc69) + %xmask = arith.constant dense<31232> : tensor<64x1xi32> loc(#loc70) + %xmask_8 = arith.cmpi slt, %xindex_7, %xmask : tensor<64x1xi32> loc(#loc70) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc71) + %r0_base_9 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc72) + %x0 = arith.constant 976 : i32 loc(#loc73) + %x0_10 = arith.constant 976 : i32 loc(#loc73) + %x0_11 = arith.constant dense<976> : tensor<64x1xi32> loc(#loc73) + %x0_12 = arith.remsi %xindex_7, %x0_11 : tensor<64x1xi32> loc(#loc73) + %x1 = arith.constant 976 : i32 loc(#loc74) + %x1_13 = arith.constant 976 : i32 loc(#loc74) + %x1_14 = arith.constant dense<976> : tensor<64x1xi32> loc(#loc74) + %x1_15 = arith.divsi %xindex_7, %x1_14 : tensor<64x1xi32> loc(#loc74) + %_tmp4 = arith.constant 0.000000e+00 : f32 loc(#loc75) + %_tmp4_16 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc75) + %c0_i32 = arith.constant 0 : i32 loc(#loc14) + %c4_i32 = arith.constant 4 : i32 loc(#loc14) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc14) + %1 = arith.bitcast %r0_numel_1 : i32 to i32 loc(#loc14) + %2 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc14) + %3 = ub.poison : i32 loc(#loc14) + %_tmp4_17 = scf.for %r0_offset = %0 to %1 step %2 iter_args(%_tmp4_23 = %_tmp4_16) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc77) + %r0_index_24 = arith.addi %r0_index, %r0_base_9 : tensor<1x4xi32> loc(#loc77) + %r0_mask = arith.constant dense<128> : tensor<1x4xi32> loc(#loc78) + %r0_mask_25 = arith.cmpi slt, %r0_index_24, %r0_mask : tensor<1x4xi32> loc(#loc78) + %tmp0 = arith.constant 128 : i32 loc(#loc79) + %tmp0_26 = arith.constant 128 : i32 loc(#loc79) + %tmp0_27 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc79) + %tmp0_28 = arith.muli %tmp0_27, %x1_15 : tensor<64x1xi32> loc(#loc79) + %tmp0_29 = tt.broadcast %r0_index_24 : tensor<1x4xi32> -> tensor<64x4xi32> loc(#loc80) + %tmp0_30 = tt.broadcast %tmp0_28 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc80) + %tmp0_31 = arith.addi %tmp0_29, %tmp0_30 : tensor<64x4xi32> loc(#loc80) + %tmp0_32 = arith.constant 4096 : i32 loc(#loc81) + %tmp0_33 = arith.constant 4096 : i32 loc(#loc81) + %tmp0_34 = arith.constant dense<4096> : tensor<64x1xi32> loc(#loc81) + %tmp0_35 = arith.muli %tmp0_34, %x0_12 : tensor<64x1xi32> loc(#loc81) + %tmp0_36 = tt.broadcast %tmp0_35 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc82) + %tmp0_37 = arith.addi %tmp0_31, %tmp0_36 : tensor<64x4xi32> loc(#loc82) + %tmp0_38 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc83) + %tmp0_39 = tt.addptr %tmp0_38, %tmp0_37 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc83) + %tmp0_40 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc84) + %tmp0_41 = tt.broadcast %xmask_8 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc84) + %tmp0_42 = arith.andi %tmp0_40, %tmp0_41 : tensor<64x4xi1> loc(#loc84) + %tmp0_43 = arith.constant 0.000000e+00 : f32 loc(#loc85) + %tmp0_44 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc85) + %tmp0_45 = arith.truncf %tmp0_44 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc85) + %tmp0_46 = tt.load %tmp0_39, %tmp0_42, %tmp0_45 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc85) + %tmp0_47 = arith.extf %tmp0_46 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc86) + %tmp1 = arith.constant 128 : i32 loc(#loc87) + %tmp1_48 = arith.constant 128 : i32 loc(#loc87) + %tmp1_49 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc87) + %tmp1_50 = arith.muli %tmp1_49, %xindex_7 : tensor<64x1xi32> loc(#loc87) + %tmp1_51 = tt.broadcast %r0_index_24 : tensor<1x4xi32> -> tensor<64x4xi32> loc(#loc88) + %tmp1_52 = tt.broadcast %tmp1_50 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc88) + %tmp1_53 = arith.addi %tmp1_51, %tmp1_52 : tensor<64x4xi32> loc(#loc88) + %tmp1_54 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc89) + %tmp1_55 = tt.addptr %tmp1_54, %tmp1_53 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc89) + %tmp1_56 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc90) + %tmp1_57 = tt.broadcast %xmask_8 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc90) + %tmp1_58 = arith.andi %tmp1_56, %tmp1_57 : tensor<64x4xi1> loc(#loc90) + %tmp1_59 = arith.constant 0.000000e+00 : f32 loc(#loc91) + %tmp1_60 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc91) + %tmp1_61 = arith.truncf %tmp1_60 : tensor<64x4xf32> to tensor<64x4xbf16> loc(#loc91) + %tmp1_62 = tt.load %tmp1_55, %tmp1_58, %tmp1_61 evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc91) + %tmp1_63 = arith.extf %tmp1_62 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc92) + %tmp2 = arith.mulf %tmp0_47, %tmp1_63 : tensor<64x4xf32> loc(#loc93) + %tmp5 = arith.addf %_tmp4_23, %tmp2 : tensor<64x4xf32> loc(#loc94) + %_tmp4_64 = tt.broadcast %r0_mask_25 : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc95) + %_tmp4_65 = tt.broadcast %xmask_8 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc95) + %_tmp4_66 = arith.andi %_tmp4_64, %_tmp4_65 : tensor<64x4xi1> loc(#loc95) + %_tmp4_67 = arith.select %_tmp4_66, %tmp5, %_tmp4_23 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc96) + scf.yield %_tmp4_67 : tensor<64x4xf32> loc(#loc35) + } loc(#loc76) + %tmp4 = tt.call @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_tmp4_17) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc97) + %tmp4_18 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc98) + %tmp7 = tt.splat %in_ptr2 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc99) + %tmp7_19 = tt.addptr %tmp7, %xindex_7 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc99) + %tmp7_20 = tt.load %tmp7_19, %xmask_8 evictionPolicy = evict_last : tensor<64x1x!tt.ptr> loc(#loc100) + %tmp8 = arith.constant 0.693147182 : f32 loc(#loc101) + %tmp9 = arith.constant dense<0.693147182> : tensor<64x1xf32> loc(#loc102) + %tmp9_21 = arith.mulf %tmp7_20, %tmp9 : tensor<64x1xf32> loc(#loc102) + %tmp10 = arith.constant 1.44269502 : f32 loc(#loc103) + %tmp11 = arith.constant dense<1.44269502> : tensor<64x1xf32> loc(#loc104) + %tmp11_22 = arith.mulf %tmp9_21, %tmp11 : tensor<64x1xf32> loc(#loc104) + %tmp12 = arith.subf %tmp4_18, %tmp11_22 : tensor<64x1xf32> loc(#loc105) + %4 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc45) + %5 = tt.addptr %4, %xindex_7 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc45) + tt.store %5, %tmp12, %xmask_8 : tensor<64x1x!tt.ptr> loc(#loc46) + tt.return loc(#loc47) + } loc(#loc) + tt.func private @"triton.language.standard.sum__fp32S64_4S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<64x4xf32> loc("input"(#loc48))) -> tensor<64xf32> attributes {noinline = false} { + %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)): + %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc49) + tt.reduce.return %2 : f32 loc(#loc49) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc49) + tt.return %0 : tensor<64xf32> loc(#loc51) + ^bb1: // no predecessors + %1 = ub.poison : tensor<64xf32> loc(#loc52) + tt.return %1 : tensor<64xf32> loc(#loc52) + } loc(#loc48) + tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32 loc("a"(#loc53)), %b: f32 loc("b"(#loc53))) -> f32 attributes {noinline = false} { + %0 = arith.addf %a, %b : f32 loc(#loc54) + tt.return %0 : f32 loc(#loc55) + ^bb1: // no predecessors + %1 = ub.poison : f32 loc(#loc56) + tt.return %1 : f32 loc(#loc56) + } loc(#loc53) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":19:13) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":20:15) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:28) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:33) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:36) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:44) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:23) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":25:21) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":26:27) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":26:37) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":28:19) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":29:19) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":31:43) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":32:40) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":33:31) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":34:29) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:45) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:41) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:55) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:50) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:34) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:70) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:60) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:122) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:45) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:41) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:34) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:60) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:50) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:112) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":40:22) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":42:23) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:35) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:48) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:8) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:25) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:28) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:30) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:35) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":47:11) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":48:18) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":49:12) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":50:19) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":51:19) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:25) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:37) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:4) +#loc49 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc51 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:11) +#loc52 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:4) +#loc54 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc55 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:11) +#loc56 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:4) +#loc63 = loc("xnumel"(#loc1)) +#loc64 = loc("r0_numel"(#loc2)) +#loc65 = loc("xoffset"(#loc3)) +#loc66 = loc("xoffset"(#loc4)) +#loc67 = loc("xindex"(#loc5)) +#loc68 = loc("xindex"(#loc6)) +#loc69 = loc("xindex"(#loc7)) +#loc70 = loc("xmask"(#loc8)) +#loc71 = loc("r0_base"(#loc9)) +#loc72 = loc("r0_base"(#loc10)) +#loc73 = loc("x0"(#loc11)) +#loc74 = loc("x1"(#loc12)) +#loc75 = loc("_tmp4"(#loc13)) +#loc76 = loc("_tmp4"(#loc14)) +#loc77 = loc("r0_index"(#loc15)) +#loc78 = loc("r0_mask"(#loc16)) +#loc79 = loc("tmp0"(#loc17)) +#loc80 = loc("tmp0"(#loc18)) +#loc81 = loc("tmp0"(#loc19)) +#loc82 = loc("tmp0"(#loc20)) +#loc83 = loc("tmp0"(#loc21)) +#loc84 = loc("tmp0"(#loc22)) +#loc85 = loc("tmp0"(#loc23)) +#loc86 = loc("tmp0"(#loc24)) +#loc87 = loc("tmp1"(#loc25)) +#loc88 = loc("tmp1"(#loc26)) +#loc89 = loc("tmp1"(#loc27)) +#loc90 = loc("tmp1"(#loc28)) +#loc91 = loc("tmp1"(#loc29)) +#loc92 = loc("tmp1"(#loc30)) +#loc93 = loc("tmp2"(#loc31)) +#loc94 = loc("tmp5"(#loc32)) +#loc95 = loc("_tmp4"(#loc33)) +#loc96 = loc("_tmp4"(#loc34)) +#loc97 = loc("tmp4"(#loc36)) +#loc98 = loc("tmp4"(#loc37)) +#loc99 = loc("tmp7"(#loc38)) +#loc100 = loc("tmp7"(#loc39)) +#loc101 = loc("tmp8"(#loc40)) +#loc102 = loc("tmp9"(#loc41)) +#loc103 = loc("tmp10"(#loc42)) +#loc104 = loc("tmp11"(#loc43)) +#loc105 = loc("tmp12"(#loc44)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttgir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..e8b8bdcaeae394b44f559f8ea56d0a39913d1822 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttgir @@ -0,0 +1,168 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":18:0) +#loc1 = loc(unknown) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:25) +#loc42 = loc("in_ptr0"(#loc)) +#loc43 = loc("in_ptr1"(#loc)) +#loc44 = loc("in_ptr2"(#loc)) +#loc45 = loc("out_ptr1"(#loc)) +#loc46 = loc("xnumel"(#loc)) +#loc47 = loc("r0_numel"(#loc)) +#loc75 = loc("tmp4"(#loc31)) +#loc83 = loc(callsite(#loc1 at #loc75)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_mul_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.693147182> : tensor<64x1xf32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<1.44269502> : tensor<64x1xf32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<31232> : tensor<64x1xi32, #blocked1> loc(#loc1) + %cst_2 = arith.constant dense<976> : tensor<64x1xi32, #blocked1> loc(#loc1) + %cst_3 = arith.constant dense<128> : tensor<64x1xi32, #blocked1> loc(#loc1) + %cst_4 = arith.constant dense<4096> : tensor<64x1xi32, #blocked1> loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<64x4xf32, #blocked1> loc(#loc1) + %cst_6 = arith.constant dense<128> : tensor<1x4xi32, #blocked1> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x4xbf16, #blocked1> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_8 = arith.constant dense<31232> : tensor<64x1xi32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc48) + %xoffset_9 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc49) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc50) + %xindex_10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc50) + %xindex_11 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> loc(#loc50) + %xindex_12 = tt.expand_dims %xindex_10 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc50) + %xindex_13 = tt.splat %xoffset_9 : i32 -> tensor<64x1xi32, #blocked1> loc(#loc51) + %xindex_14 = tt.splat %xoffset_9 : i32 -> tensor<64x1xi32, #blocked> loc(#loc51) + %xindex_15 = arith.addi %xindex_13, %xindex_11 : tensor<64x1xi32, #blocked1> loc(#loc51) + %xindex_16 = arith.addi %xindex_14, %xindex_12 : tensor<64x1xi32, #blocked> loc(#loc51) + %xmask = arith.cmpi slt, %xindex_15, %cst_1 : tensor<64x1xi32, #blocked1> loc(#loc52) + %xmask_17 = arith.cmpi slt, %xindex_16, %cst_8 : tensor<64x1xi32, #blocked> loc(#loc52) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc53) + %r0_base_18 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x4xi32, #blocked1> loc(#loc53) + %x0 = arith.remsi %xindex_15, %cst_2 : tensor<64x1xi32, #blocked1> loc(#loc54) + %x1 = arith.divsi %xindex_15, %cst_2 : tensor<64x1xi32, #blocked1> loc(#loc55) + %tmp0 = arith.muli %x1, %cst_3 : tensor<64x1xi32, #blocked1> loc(#loc56) + %tmp0_19 = tt.broadcast %tmp0 : tensor<64x1xi32, #blocked1> -> tensor<64x4xi32, #blocked1> loc(#loc57) + %tmp0_20 = arith.muli %x0, %cst_4 : tensor<64x1xi32, #blocked1> loc(#loc58) + %tmp0_21 = tt.broadcast %tmp0_20 : tensor<64x1xi32, #blocked1> -> tensor<64x4xi32, #blocked1> loc(#loc59) + %tmp0_22 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked1> loc(#loc60) + %tmp0_23 = tt.broadcast %xmask : tensor<64x1xi1, #blocked1> -> tensor<64x4xi1, #blocked1> loc(#loc61) + %tmp1 = arith.muli %xindex_15, %cst_3 : tensor<64x1xi32, #blocked1> loc(#loc62) + %tmp1_24 = tt.broadcast %tmp1 : tensor<64x1xi32, #blocked1> -> tensor<64x4xi32, #blocked1> loc(#loc63) + %tmp1_25 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr, #blocked1> loc(#loc64) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_30 = %cst_5) -> (tensor<64x4xf32, #blocked1>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32, #blocked1> loc(#loc66) + %r0_index_31 = arith.addi %r0_index, %r0_base_18 : tensor<1x4xi32, #blocked1> loc(#loc66) + %r0_mask = arith.cmpi slt, %r0_index_31, %cst_6 : tensor<1x4xi32, #blocked1> loc(#loc67) + %tmp0_32 = tt.broadcast %r0_index_31 : tensor<1x4xi32, #blocked1> -> tensor<64x4xi32, #blocked1> loc(#loc57) + %tmp0_33 = arith.addi %tmp0_32, %tmp0_19 : tensor<64x4xi32, #blocked1> loc(#loc57) + %tmp0_34 = arith.addi %tmp0_33, %tmp0_21 : tensor<64x4xi32, #blocked1> loc(#loc59) + %tmp0_35 = tt.addptr %tmp0_22, %tmp0_34 : tensor<64x4x!tt.ptr, #blocked1>, tensor<64x4xi32, #blocked1> loc(#loc60) + %tmp0_36 = tt.broadcast %r0_mask : tensor<1x4xi1, #blocked1> -> tensor<64x4xi1, #blocked1> loc(#loc61) + %tmp0_37 = arith.andi %tmp0_36, %tmp0_23 : tensor<64x4xi1, #blocked1> loc(#loc61) + %tmp0_38 = tt.load %tmp0_35, %tmp0_37, %cst_7 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked1> loc(#loc68) + %tmp0_39 = arith.extf %tmp0_38 : tensor<64x4xbf16, #blocked1> to tensor<64x4xf32, #blocked1> loc(#loc69) + %tmp1_40 = arith.addi %tmp0_32, %tmp1_24 : tensor<64x4xi32, #blocked1> loc(#loc63) + %tmp1_41 = tt.addptr %tmp1_25, %tmp1_40 : tensor<64x4x!tt.ptr, #blocked1>, tensor<64x4xi32, #blocked1> loc(#loc64) + %tmp1_42 = tt.load %tmp1_41, %tmp0_37, %cst_7 evictionPolicy = evict_first : tensor<64x4x!tt.ptr, #blocked1> loc(#loc70) + %tmp1_43 = arith.extf %tmp1_42 : tensor<64x4xbf16, #blocked1> to tensor<64x4xf32, #blocked1> loc(#loc71) + %tmp2 = arith.mulf %tmp0_39, %tmp1_43 : tensor<64x4xf32, #blocked1> loc(#loc72) + %tmp5 = arith.addf %_tmp4_30, %tmp2 : tensor<64x4xf32, #blocked1> loc(#loc73) + %_tmp4_44 = arith.select %tmp0_37, %tmp5, %_tmp4_30 : tensor<64x4xi1, #blocked1>, tensor<64x4xf32, #blocked1> loc(#loc74) + scf.yield %_tmp4_44 : tensor<64x4xf32, #blocked1> loc(#loc29) + } loc(#loc65) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_30: f32 loc(callsite(#loc1 at #loc75)), %tmp4_31: f32 loc(callsite(#loc1 at #loc75))): + %tmp4_32 = arith.addf %tmp4_30, %tmp4_31 : f32 loc(#loc84) + tt.reduce.return %tmp4_32 : f32 loc(#loc82) + }) : (tensor<64x4xf32, #blocked1>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc82) + %tmp12 = ttg.convert_layout %tmp4 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc76) + %tmp4_26 = tt.expand_dims %tmp12 {axis = 1 : i32} : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xf32, #blocked> loc(#loc77) + %tmp7 = tt.splat %in_ptr2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> loc(#loc78) + %tmp7_27 = tt.addptr %tmp7, %xindex_16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc78) + %tmp7_28 = tt.load %tmp7_27, %xmask_17 evictionPolicy = evict_last : tensor<64x1x!tt.ptr, #blocked> loc(#loc79) + %tmp9 = arith.mulf %tmp7_28, %cst : tensor<64x1xf32, #blocked> loc(#loc80) + %tmp11 = arith.mulf %tmp9, %cst_0 : tensor<64x1xf32, #blocked> loc(#loc81) + %tmp12_29 = arith.subf %tmp4_26, %tmp11 : tensor<64x1xf32, #blocked> loc(#loc76) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> loc(#loc39) + %1 = tt.addptr %0, %xindex_16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc39) + tt.store %1, %tmp12_29, %xmask_17 : tensor<64x1x!tt.ptr, #blocked> loc(#loc40) + tt.return loc(#loc41) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:28) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:33) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:44) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:23) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":25:21) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":26:37) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":28:19) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":29:19) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:45) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:41) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:55) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:50) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:34) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:70) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:45) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:41) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:34) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":32:40) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":33:31) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":34:29) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:60) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:122) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:50) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:112) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":40:22) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":42:23) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:48) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:8) +#loc30 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc32 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":51:19) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:28) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:30) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:35) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":48:18) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":50:19) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:25) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:37) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:4) +#loc48 = loc("xoffset"(#loc2)) +#loc49 = loc("xoffset"(#loc3)) +#loc50 = loc("xindex"(#loc4)) +#loc51 = loc("xindex"(#loc5)) +#loc52 = loc("xmask"(#loc6)) +#loc53 = loc("r0_base"(#loc7)) +#loc54 = loc("x0"(#loc8)) +#loc55 = loc("x1"(#loc9)) +#loc56 = loc("tmp0"(#loc10)) +#loc57 = loc("tmp0"(#loc11)) +#loc58 = loc("tmp0"(#loc12)) +#loc59 = loc("tmp0"(#loc13)) +#loc60 = loc("tmp0"(#loc14)) +#loc61 = loc("tmp0"(#loc15)) +#loc62 = loc("tmp1"(#loc16)) +#loc63 = loc("tmp1"(#loc17)) +#loc64 = loc("tmp1"(#loc18)) +#loc65 = loc("_tmp4"(#loc19)) +#loc66 = loc("r0_index"(#loc20)) +#loc67 = loc("r0_mask"(#loc21)) +#loc68 = loc("tmp0"(#loc22)) +#loc69 = loc("tmp0"(#loc23)) +#loc70 = loc("tmp1"(#loc24)) +#loc71 = loc("tmp1"(#loc25)) +#loc72 = loc("tmp2"(#loc26)) +#loc73 = loc("tmp5"(#loc27)) +#loc74 = loc("_tmp4"(#loc28)) +#loc76 = loc("tmp12"(#loc33)) +#loc77 = loc("tmp4"(#loc34)) +#loc78 = loc("tmp7"(#loc35)) +#loc79 = loc("tmp7"(#loc36)) +#loc80 = loc("tmp9"(#loc37)) +#loc81 = loc("tmp11"(#loc38)) +#loc82 = loc(callsite(#loc30 at #loc75)) +#loc84 = loc(callsite(#loc32 at #loc82)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttir new file mode 100644 index 0000000000000000000000000000000000000000..e845b589c2370ff3c7e9c102f2266607646bffc8 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/6ZJLDYNAAFB5SZLIMJUTM3ARAJ5NXSJQFBWJUMP5ZGIO2UHHNW2A/triton_red_fused_mul_0.ttir @@ -0,0 +1,163 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":18:0) +#loc1 = loc(unknown) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:25) +#loc44 = loc("in_ptr0"(#loc)) +#loc45 = loc("in_ptr1"(#loc)) +#loc46 = loc("in_ptr2"(#loc)) +#loc47 = loc("out_ptr1"(#loc)) +#loc48 = loc("xnumel"(#loc)) +#loc49 = loc("r0_numel"(#loc)) +#loc81 = loc("tmp4"(#loc35)) +#loc87 = loc(callsite(#loc1 at #loc81)) +module { + tt.func public @triton_red_fused_mul_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x4xbf16> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc2) + %c128_i32 = arith.constant 128 : i32 loc(#loc2) + %c0_i32 = arith.constant 0 : i32 loc(#loc2) + %tmp11 = arith.constant dense<1.44269502> : tensor<64x1xf32> loc(#loc50) + %tmp9 = arith.constant dense<0.693147182> : tensor<64x1xf32> loc(#loc51) + %cst_0 = arith.constant dense<4096> : tensor<64x1xi32> loc(#loc1) + %cst_1 = arith.constant dense<128> : tensor<64x1xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<1x4xi32> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x4xf32> loc(#loc1) + %cst_4 = arith.constant dense<976> : tensor<64x1xi32> loc(#loc1) + %xmask = arith.constant dense<31232> : tensor<64x1xi32> loc(#loc52) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc53) + %xoffset_5 = arith.muli %xoffset, %c64_i32 : i32 loc(#loc54) + %xindex = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc55) + %xindex_6 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc56) + %xindex_7 = tt.splat %xoffset_5 : i32 -> tensor<64x1xi32> loc(#loc57) + %xindex_8 = arith.addi %xindex_7, %xindex_6 : tensor<64x1xi32> loc(#loc57) + %xmask_9 = arith.cmpi slt, %xindex_8, %xmask : tensor<64x1xi32> loc(#loc52) + %r0_base = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> loc(#loc58) + %r0_base_10 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> loc(#loc59) + %x0 = arith.remsi %xindex_8, %cst_4 : tensor<64x1xi32> loc(#loc60) + %x1 = arith.divsi %xindex_8, %cst_4 : tensor<64x1xi32> loc(#loc61) + %_tmp4 = scf.for %r0_offset = %c0_i32 to %c128_i32 step %c4_i32 iter_args(%_tmp4_16 = %cst_3) -> (tensor<64x4xf32>) : i32 { + %r0_index = tt.splat %r0_offset : i32 -> tensor<1x4xi32> loc(#loc63) + %r0_index_17 = arith.addi %r0_index, %r0_base_10 : tensor<1x4xi32> loc(#loc63) + %r0_mask = arith.cmpi slt, %r0_index_17, %cst_2 : tensor<1x4xi32> loc(#loc64) + %tmp0 = arith.muli %x1, %cst_1 : tensor<64x1xi32> loc(#loc65) + %tmp0_18 = tt.broadcast %r0_index_17 : tensor<1x4xi32> -> tensor<64x4xi32> loc(#loc66) + %tmp0_19 = tt.broadcast %tmp0 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc66) + %tmp0_20 = arith.addi %tmp0_18, %tmp0_19 : tensor<64x4xi32> loc(#loc66) + %tmp0_21 = arith.muli %x0, %cst_0 : tensor<64x1xi32> loc(#loc67) + %tmp0_22 = tt.broadcast %tmp0_21 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc68) + %tmp0_23 = arith.addi %tmp0_20, %tmp0_22 : tensor<64x4xi32> loc(#loc68) + %tmp0_24 = tt.splat %in_ptr0 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc69) + %tmp0_25 = tt.addptr %tmp0_24, %tmp0_23 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc69) + %tmp0_26 = tt.broadcast %r0_mask : tensor<1x4xi1> -> tensor<64x4xi1> loc(#loc70) + %tmp0_27 = tt.broadcast %xmask_9 : tensor<64x1xi1> -> tensor<64x4xi1> loc(#loc70) + %tmp0_28 = arith.andi %tmp0_26, %tmp0_27 : tensor<64x4xi1> loc(#loc70) + %tmp0_29 = tt.load %tmp0_25, %tmp0_28, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc71) + %tmp0_30 = arith.extf %tmp0_29 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc72) + %tmp1 = arith.muli %xindex_8, %cst_1 : tensor<64x1xi32> loc(#loc73) + %tmp1_31 = tt.broadcast %tmp1 : tensor<64x1xi32> -> tensor<64x4xi32> loc(#loc74) + %tmp1_32 = arith.addi %tmp0_18, %tmp1_31 : tensor<64x4xi32> loc(#loc74) + %tmp1_33 = tt.splat %in_ptr1 : !tt.ptr -> tensor<64x4x!tt.ptr> loc(#loc75) + %tmp1_34 = tt.addptr %tmp1_33, %tmp1_32 : tensor<64x4x!tt.ptr>, tensor<64x4xi32> loc(#loc75) + %tmp1_35 = tt.load %tmp1_34, %tmp0_28, %cst evictionPolicy = evict_first : tensor<64x4x!tt.ptr> loc(#loc76) + %tmp1_36 = arith.extf %tmp1_35 : tensor<64x4xbf16> to tensor<64x4xf32> loc(#loc77) + %tmp2 = arith.mulf %tmp0_30, %tmp1_36 : tensor<64x4xf32> loc(#loc78) + %tmp5 = arith.addf %_tmp4_16, %tmp2 : tensor<64x4xf32> loc(#loc79) + %_tmp4_37 = arith.select %tmp0_28, %tmp5, %_tmp4_16 : tensor<64x4xi1>, tensor<64x4xf32> loc(#loc80) + scf.yield %_tmp4_37 : tensor<64x4xf32> loc(#loc33) + } loc(#loc62) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_16: f32 loc(callsite(#loc1 at #loc81)), %tmp4_17: f32 loc(callsite(#loc1 at #loc81))): + %tmp4_18 = arith.addf %tmp4_16, %tmp4_17 : f32 loc(#loc88) + tt.reduce.return %tmp4_18 : f32 loc(#loc86) + }) : (tensor<64x4xf32>) -> tensor<64xf32> loc(#loc86) + %tmp4_11 = tt.expand_dims %tmp4 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32> loc(#loc82) + %tmp7 = tt.splat %in_ptr2 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc83) + %tmp7_12 = tt.addptr %tmp7, %xindex_8 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc83) + %tmp7_13 = tt.load %tmp7_12, %xmask_9 evictionPolicy = evict_last : tensor<64x1x!tt.ptr> loc(#loc84) + %tmp9_14 = arith.mulf %tmp7_13, %tmp9 : tensor<64x1xf32> loc(#loc51) + %tmp11_15 = arith.mulf %tmp9_14, %tmp11 : tensor<64x1xf32> loc(#loc50) + %tmp12 = arith.subf %tmp4_11, %tmp11_15 : tensor<64x1xf32> loc(#loc85) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc41) + %1 = tt.addptr %0, %xindex_8 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc41) + tt.store %1, %tmp12, %xmask_9 : tensor<64x1x!tt.ptr> loc(#loc42) + tt.return loc(#loc43) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":32:40) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":50:19) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":48:18) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":25:21) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:28) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":23:33) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:36) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:44) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":24:23) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":26:27) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":26:37) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":28:19) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":29:19) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":33:31) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":34:29) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:45) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:41) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:55) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:50) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:34) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:70) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:60) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":38:122) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:45) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:41) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:34) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:50) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":39:112) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":40:22) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":42:23) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:48) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":43:8) +#loc34 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc36 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":44:28) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:30) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":45:35) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":51:19) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:25) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:37) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/ps/cpsnp7pecxfexl5dketrevlgtfuwppwquks37l2xfyacmp4rt7a5.py":52:4) +#loc50 = loc("tmp11"(#loc3)) +#loc51 = loc("tmp9"(#loc4)) +#loc52 = loc("xmask"(#loc5)) +#loc53 = loc("xoffset"(#loc6)) +#loc54 = loc("xoffset"(#loc7)) +#loc55 = loc("xindex"(#loc8)) +#loc56 = loc("xindex"(#loc9)) +#loc57 = loc("xindex"(#loc10)) +#loc58 = loc("r0_base"(#loc11)) +#loc59 = loc("r0_base"(#loc12)) +#loc60 = loc("x0"(#loc13)) +#loc61 = loc("x1"(#loc14)) +#loc62 = loc("_tmp4"(#loc2)) +#loc63 = loc("r0_index"(#loc15)) +#loc64 = loc("r0_mask"(#loc16)) +#loc65 = loc("tmp0"(#loc17)) +#loc66 = loc("tmp0"(#loc18)) +#loc67 = loc("tmp0"(#loc19)) +#loc68 = loc("tmp0"(#loc20)) +#loc69 = loc("tmp0"(#loc21)) +#loc70 = loc("tmp0"(#loc22)) +#loc71 = loc("tmp0"(#loc23)) +#loc72 = loc("tmp0"(#loc24)) +#loc73 = loc("tmp1"(#loc25)) +#loc74 = loc("tmp1"(#loc26)) +#loc75 = loc("tmp1"(#loc27)) +#loc76 = loc("tmp1"(#loc28)) +#loc77 = loc("tmp1"(#loc29)) +#loc78 = loc("tmp2"(#loc30)) +#loc79 = loc("tmp5"(#loc31)) +#loc80 = loc("_tmp4"(#loc32)) +#loc82 = loc("tmp4"(#loc37)) +#loc83 = loc("tmp7"(#loc38)) +#loc84 = loc("tmp7"(#loc39)) +#loc85 = loc("tmp12"(#loc40)) +#loc86 = loc(callsite(#loc34 at #loc81)) +#loc88 = loc(callsite(#loc36 at #loc86)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/__grp__triton_tem_fused_mul_1.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/__grp__triton_tem_fused_mul_1.json new file mode 100644 index 0000000000000000000000000000000000000000..ef1af22a2d7f7aa1ef48a0b13e1769d75cf15d5d --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/__grp__triton_tem_fused_mul_1.json @@ -0,0 +1 @@ +{"child_paths": {"triton_tem_fused_mul_1.source": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.source", "triton_tem_fused_mul_1.ttir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttir", "triton_tem_fused_mul_1.ttgir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttgir", "triton_tem_fused_mul_1.llir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.llir", "triton_tem_fused_mul_1.ptx": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ptx", "triton_tem_fused_mul_1.cubin": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.cubin", "triton_tem_fused_mul_1.json": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.json"}} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.json new file mode 100644 index 0000000000000000000000000000000000000000..99d670bd7b66040a06205898e7870230d9beac7b --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.json @@ -0,0 +1 @@ +{"hash": "fdf02639e61a052085bc6f3d4b3b415d01ac9f98c2dda2f65525dae22fcdaa63", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 3, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 164864, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_tem_fused_mul_1"} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.llir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.llir new file mode 100644 index 0000000000000000000000000000000000000000..644b673d5535d68f9cb61877912955f65bd8640c --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.llir @@ -0,0 +1,14128 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 +@.str = private unnamed_addr constant [11 x i8] c"__CUDA_FTZ\00", align 1 + +; Function Attrs: nounwind +define ptx_kernel void @triton_tem_fused_mul_1(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, ptr addrspace(1) %7, ptr addrspace(1) %8, ptr addrspace(1) %9, ptr addrspace(1) %10, ptr addrspace(1) %11, ptr addrspace(1) %12, ptr addrspace(1) %13, ptr addrspace(1) %14, ptr addrspace(1) %15, ptr addrspace(1) %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, ptr addrspace(1) readnone captures(none) %25, ptr addrspace(1) readnone captures(none) %26) local_unnamed_addr #0 !dbg !5 { + %28 = shl i32 %17, 12, !dbg !8 + %29 = icmp slt i32 %17, 2, !dbg !9 + %30 = zext i1 %29 to i32, !dbg !10 + %31 = icmp sgt i32 %17, 1, !dbg !11 + %32 = select i1 %31, i32 %17, i32 0, !dbg !12 + %33 = add i32 %32, %30, !dbg !13 + %34 = shl i32 %33, 12, !dbg !14 + %35 = shl i32 %33, 7, !dbg !15 + %36 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !16 + %37 = add i32 %18, 127, !dbg !17 + %38 = sdiv i32 %37, 128, !dbg !21 + %39 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !22 + %40 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !dbg !23 + %41 = shl nuw nsw i32 %40, 7, !dbg !24 + %42 = zext nneg i32 %41 to i64, !dbg !25 + %43 = shl nuw nsw i32 %39, 10, !dbg !26 + %44 = mul i32 %43, %18, !dbg !27 + %45 = add i32 %44, %41, !dbg !28 + %46 = sext i32 %45 to i64, !dbg !29 + %47 = getelementptr bfloat, ptr addrspace(1) %1, i64 %42, !dbg !30 + %48 = getelementptr bfloat, ptr addrspace(1) %2, i64 %42, !dbg !31 + %49 = getelementptr bfloat, ptr addrspace(1) %7, i64 %46, !dbg !32 + %50 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !33 + %51 = lshr i32 %50, 5, !dbg !33 + %52 = and i32 %50, 240, !dbg !33 + %53 = lshr exact i32 %52, 4, !dbg !33 + %54 = or disjoint i32 %53, 16, !dbg !33 + %55 = or disjoint i32 %53, 32, !dbg !33 + %56 = or disjoint i32 %53, 48, !dbg !33 + %57 = or disjoint i32 %53, 64, !dbg !33 + %58 = or disjoint i32 %53, 80, !dbg !33 + %59 = or disjoint i32 %53, 96, !dbg !33 + %60 = or disjoint i32 %53, 112, !dbg !33 + %61 = lshr i32 %50, 1, !dbg !33 + %62 = and i32 %61, 112, !dbg !33 + %63 = lshr i32 %50, 2, !dbg !33 + %64 = and i32 %63, 7, !dbg !33 + %65 = or disjoint i32 %62, %64, !dbg !33 + %66 = or disjoint i32 %65, 8, !dbg !33 + %.not = icmp slt i32 %36, %38, !dbg !34 + br i1 %.not, label %4740, label %67, !dbg !35 + +67: ; preds = %27 + %68 = add i32 %17, 127, !dbg !36 + %69 = sdiv i32 %68, 128, !dbg !38 + %70 = sub i32 %36, %38, !dbg !39 + %.frozen = freeze i32 %70, !dbg !40 + %.frozen4126 = freeze i32 %69, !dbg !40 + %71 = sdiv i32 %.frozen, %.frozen4126, !dbg !40 + %72 = shl nuw nsw i32 %40, 2, !dbg !41 + %73 = add i32 %71, %72, !dbg !42 + %74 = mul i32 %71, %.frozen4126, !dbg !43 + %.decomposed = sub i32 %.frozen, %74, !dbg !43 + %75 = mul i32 %.decomposed, %21, !dbg !44 + %76 = shl i32 %73, 7, !dbg !45 + %77 = mul i32 %28, %39, !dbg !46 + %78 = add i32 %76, %77, !dbg !47 + %79 = sext i32 %78 to i64, !dbg !48 + %80 = mul i32 %73, %35, !dbg !49 + %81 = mul i32 %34, %39, !dbg !50 + %82 = add i32 %80, %81, !dbg !51 + %83 = sext i32 %82 to i64, !dbg !52 + %84 = shl nuw nsw i32 %39, 5, !dbg !53 + %85 = add i32 %73, %84, !dbg !54 + %86 = mul i32 %85, %17, !dbg !55 + %87 = sext i32 %86 to i64, !dbg !56 + %88 = getelementptr bfloat, ptr addrspace(1) %0, i64 %79, !dbg !57 + %89 = getelementptr bfloat, ptr addrspace(1) %5, i64 %83, !dbg !58 + %90 = getelementptr bfloat, ptr addrspace(1) %6, i64 %79, !dbg !59 + %91 = getelementptr float, ptr addrspace(1) %3, i64 %87, !dbg !60 + %92 = getelementptr float, ptr addrspace(1) %4, i64 %87, !dbg !61 + %93 = shl nsw i32 %.decomposed, 7, !dbg !62 + %94 = or disjoint i32 %93, %53, !dbg !63 + %95 = or disjoint i32 %93, %54, !dbg !63 + %96 = or disjoint i32 %93, %55, !dbg !63 + %97 = or disjoint i32 %93, %56, !dbg !63 + %98 = or disjoint i32 %93, %57, !dbg !63 + %99 = or disjoint i32 %93, %58, !dbg !63 + %100 = or disjoint i32 %93, %59, !dbg !63 + %101 = or disjoint i32 %93, %60, !dbg !63 + %102 = or disjoint i32 %93, %65, !dbg !63 + %103 = or disjoint i32 %93, %66, !dbg !63 + %104 = shl i32 %94, 12, !dbg !64 + %105 = shl i32 %95, 12, !dbg !64 + %106 = shl i32 %96, 12, !dbg !64 + %107 = shl i32 %97, 12, !dbg !64 + %108 = shl i32 %98, 12, !dbg !64 + %109 = shl i32 %99, 12, !dbg !64 + %110 = shl i32 %100, 12, !dbg !64 + %111 = shl i32 %101, 12, !dbg !64 + %112 = sext i32 %104 to i64, !dbg !67 + %113 = getelementptr bfloat, ptr addrspace(1) %88, i64 %112, !dbg !67 + %114 = sext i32 %105 to i64, !dbg !67 + %115 = getelementptr bfloat, ptr addrspace(1) %88, i64 %114, !dbg !67 + %116 = sext i32 %106 to i64, !dbg !67 + %117 = getelementptr bfloat, ptr addrspace(1) %88, i64 %116, !dbg !67 + %118 = sext i32 %107 to i64, !dbg !67 + %119 = getelementptr bfloat, ptr addrspace(1) %88, i64 %118, !dbg !67 + %120 = sext i32 %108 to i64, !dbg !67 + %121 = getelementptr bfloat, ptr addrspace(1) %88, i64 %120, !dbg !67 + %122 = sext i32 %109 to i64, !dbg !67 + %123 = getelementptr bfloat, ptr addrspace(1) %88, i64 %122, !dbg !67 + %124 = sext i32 %110 to i64, !dbg !67 + %125 = getelementptr bfloat, ptr addrspace(1) %88, i64 %124, !dbg !67 + %126 = sext i32 %111 to i64, !dbg !67 + %127 = getelementptr bfloat, ptr addrspace(1) %88, i64 %126, !dbg !67 + %128 = shl nuw nsw i32 %50, 3, !dbg !68 + %129 = and i32 %128, 120, !dbg !68 + %130 = zext nneg i32 %129 to i64, !dbg !69 + %131 = getelementptr bfloat, ptr addrspace(1) %113, i64 %130, !dbg !69 + %132 = getelementptr bfloat, ptr addrspace(1) %115, i64 %130, !dbg !69 + %133 = getelementptr bfloat, ptr addrspace(1) %117, i64 %130, !dbg !69 + %134 = getelementptr bfloat, ptr addrspace(1) %119, i64 %130, !dbg !69 + %135 = getelementptr bfloat, ptr addrspace(1) %121, i64 %130, !dbg !69 + %136 = getelementptr bfloat, ptr addrspace(1) %123, i64 %130, !dbg !69 + %137 = getelementptr bfloat, ptr addrspace(1) %125, i64 %130, !dbg !69 + %138 = getelementptr bfloat, ptr addrspace(1) %127, i64 %130, !dbg !69 + %139 = icmp slt i32 %94, %17, !dbg !70 + %140 = icmp slt i32 %95, %17, !dbg !70 + %141 = icmp slt i32 %96, %17, !dbg !70 + %142 = icmp slt i32 %97, %17, !dbg !70 + %143 = icmp slt i32 %98, %17, !dbg !70 + %144 = icmp slt i32 %99, %17, !dbg !70 + %145 = icmp slt i32 %100, %17, !dbg !70 + %146 = icmp slt i32 %101, %17, !dbg !70 + %147 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %131, i1 %139) #3, !dbg !71 + %148 = extractvalue { i32, i32, i32, i32 } %147, 0, !dbg !71 + %149 = extractvalue { i32, i32, i32, i32 } %147, 1, !dbg !71 + %150 = extractvalue { i32, i32, i32, i32 } %147, 2, !dbg !71 + %151 = extractvalue { i32, i32, i32, i32 } %147, 3, !dbg !71 + %152 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %132, i1 %140) #3, !dbg !71 + %153 = extractvalue { i32, i32, i32, i32 } %152, 0, !dbg !71 + %154 = extractvalue { i32, i32, i32, i32 } %152, 1, !dbg !71 + %155 = extractvalue { i32, i32, i32, i32 } %152, 2, !dbg !71 + %156 = extractvalue { i32, i32, i32, i32 } %152, 3, !dbg !71 + %157 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %133, i1 %141) #3, !dbg !71 + %158 = extractvalue { i32, i32, i32, i32 } %157, 0, !dbg !71 + %159 = extractvalue { i32, i32, i32, i32 } %157, 1, !dbg !71 + %160 = extractvalue { i32, i32, i32, i32 } %157, 2, !dbg !71 + %161 = extractvalue { i32, i32, i32, i32 } %157, 3, !dbg !71 + %162 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %134, i1 %142) #3, !dbg !71 + %163 = extractvalue { i32, i32, i32, i32 } %162, 0, !dbg !71 + %164 = extractvalue { i32, i32, i32, i32 } %162, 1, !dbg !71 + %165 = extractvalue { i32, i32, i32, i32 } %162, 2, !dbg !71 + %166 = extractvalue { i32, i32, i32, i32 } %162, 3, !dbg !71 + %167 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %135, i1 %143) #3, !dbg !71 + %168 = extractvalue { i32, i32, i32, i32 } %167, 0, !dbg !71 + %169 = extractvalue { i32, i32, i32, i32 } %167, 1, !dbg !71 + %170 = extractvalue { i32, i32, i32, i32 } %167, 2, !dbg !71 + %171 = extractvalue { i32, i32, i32, i32 } %167, 3, !dbg !71 + %172 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %136, i1 %144) #3, !dbg !71 + %173 = extractvalue { i32, i32, i32, i32 } %172, 0, !dbg !71 + %174 = extractvalue { i32, i32, i32, i32 } %172, 1, !dbg !71 + %175 = extractvalue { i32, i32, i32, i32 } %172, 2, !dbg !71 + %176 = extractvalue { i32, i32, i32, i32 } %172, 3, !dbg !71 + %177 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %137, i1 %145) #3, !dbg !71 + %178 = extractvalue { i32, i32, i32, i32 } %177, 0, !dbg !71 + %179 = extractvalue { i32, i32, i32, i32 } %177, 1, !dbg !71 + %180 = extractvalue { i32, i32, i32, i32 } %177, 2, !dbg !71 + %181 = extractvalue { i32, i32, i32, i32 } %177, 3, !dbg !71 + %182 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %138, i1 %146) #3, !dbg !71 + %183 = extractvalue { i32, i32, i32, i32 } %182, 0, !dbg !71 + %184 = extractvalue { i32, i32, i32, i32 } %182, 1, !dbg !71 + %185 = extractvalue { i32, i32, i32, i32 } %182, 2, !dbg !71 + %186 = extractvalue { i32, i32, i32, i32 } %182, 3, !dbg !71 + %187 = shl nuw nsw i32 %50, 4, !dbg !71 + %188 = and i32 %187, 112, !dbg !71 + %189 = shl nuw nsw i32 %52, 3, !dbg !71 + %190 = and i32 %50, 112, !dbg !71 + %191 = and i32 %50, 8, !dbg !71 + %192 = shl nuw nsw i32 %191, 11, !dbg !71 + %193 = or disjoint i32 %188, %189, !dbg !71 + %194 = xor i32 %193, %190, !dbg !71 + %195 = or disjoint i32 %194, %192, !dbg !71 + %196 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %195, !dbg !71 + %197 = insertelement <4 x i32> poison, i32 %148, i64 0, !dbg !71 + %198 = insertelement <4 x i32> %197, i32 %149, i64 1, !dbg !71 + %199 = insertelement <4 x i32> %198, i32 %150, i64 2, !dbg !71 + %200 = insertelement <4 x i32> %199, i32 %151, i64 3, !dbg !71 + store <4 x i32> %200, ptr addrspace(3) %196, align 16, !dbg !71 + %201 = or disjoint i32 %195, 2048, !dbg !71 + %202 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %201, !dbg !71 + %203 = insertelement <4 x i32> poison, i32 %153, i64 0, !dbg !71 + %204 = insertelement <4 x i32> %203, i32 %154, i64 1, !dbg !71 + %205 = insertelement <4 x i32> %204, i32 %155, i64 2, !dbg !71 + %206 = insertelement <4 x i32> %205, i32 %156, i64 3, !dbg !71 + store <4 x i32> %206, ptr addrspace(3) %202, align 16, !dbg !71 + %207 = or disjoint i32 %195, 4096, !dbg !71 + %208 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %207, !dbg !71 + %209 = insertelement <4 x i32> poison, i32 %158, i64 0, !dbg !71 + %210 = insertelement <4 x i32> %209, i32 %159, i64 1, !dbg !71 + %211 = insertelement <4 x i32> %210, i32 %160, i64 2, !dbg !71 + %212 = insertelement <4 x i32> %211, i32 %161, i64 3, !dbg !71 + store <4 x i32> %212, ptr addrspace(3) %208, align 16, !dbg !71 + %213 = or disjoint i32 %195, 6144, !dbg !71 + %214 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %213, !dbg !71 + %215 = insertelement <4 x i32> poison, i32 %163, i64 0, !dbg !71 + %216 = insertelement <4 x i32> %215, i32 %164, i64 1, !dbg !71 + %217 = insertelement <4 x i32> %216, i32 %165, i64 2, !dbg !71 + %218 = insertelement <4 x i32> %217, i32 %166, i64 3, !dbg !71 + store <4 x i32> %218, ptr addrspace(3) %214, align 16, !dbg !71 + %219 = or disjoint i32 %195, 8192, !dbg !71 + %220 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %219, !dbg !71 + %221 = insertelement <4 x i32> poison, i32 %168, i64 0, !dbg !71 + %222 = insertelement <4 x i32> %221, i32 %169, i64 1, !dbg !71 + %223 = insertelement <4 x i32> %222, i32 %170, i64 2, !dbg !71 + %224 = insertelement <4 x i32> %223, i32 %171, i64 3, !dbg !71 + store <4 x i32> %224, ptr addrspace(3) %220, align 16, !dbg !71 + %225 = or disjoint i32 %195, 10240, !dbg !71 + %226 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %225, !dbg !71 + %227 = insertelement <4 x i32> poison, i32 %173, i64 0, !dbg !71 + %228 = insertelement <4 x i32> %227, i32 %174, i64 1, !dbg !71 + %229 = insertelement <4 x i32> %228, i32 %175, i64 2, !dbg !71 + %230 = insertelement <4 x i32> %229, i32 %176, i64 3, !dbg !71 + store <4 x i32> %230, ptr addrspace(3) %226, align 16, !dbg !71 + %231 = or disjoint i32 %195, 12288, !dbg !71 + %232 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %231, !dbg !71 + %233 = insertelement <4 x i32> poison, i32 %178, i64 0, !dbg !71 + %234 = insertelement <4 x i32> %233, i32 %179, i64 1, !dbg !71 + %235 = insertelement <4 x i32> %234, i32 %180, i64 2, !dbg !71 + %236 = insertelement <4 x i32> %235, i32 %181, i64 3, !dbg !71 + store <4 x i32> %236, ptr addrspace(3) %232, align 16, !dbg !71 + %237 = or disjoint i32 %195, 14336, !dbg !71 + %238 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %237, !dbg !71 + %239 = insertelement <4 x i32> poison, i32 %183, i64 0, !dbg !71 + %240 = insertelement <4 x i32> %239, i32 %184, i64 1, !dbg !71 + %241 = insertelement <4 x i32> %240, i32 %185, i64 2, !dbg !71 + %242 = insertelement <4 x i32> %241, i32 %186, i64 3, !dbg !71 + store <4 x i32> %242, ptr addrspace(3) %238, align 16, !dbg !71 + %243 = shl i32 %94, 7, !dbg !72 + %244 = shl i32 %95, 7, !dbg !72 + %245 = shl i32 %96, 7, !dbg !72 + %246 = shl i32 %97, 7, !dbg !72 + %247 = shl i32 %98, 7, !dbg !72 + %248 = shl i32 %99, 7, !dbg !72 + %249 = shl i32 %100, 7, !dbg !72 + %250 = shl i32 %101, 7, !dbg !72 + %251 = sext i32 %243 to i64, !dbg !74 + %252 = getelementptr bfloat, ptr addrspace(1) %89, i64 %251, !dbg !74 + %253 = sext i32 %244 to i64, !dbg !74 + %254 = getelementptr bfloat, ptr addrspace(1) %89, i64 %253, !dbg !74 + %255 = sext i32 %245 to i64, !dbg !74 + %256 = getelementptr bfloat, ptr addrspace(1) %89, i64 %255, !dbg !74 + %257 = sext i32 %246 to i64, !dbg !74 + %258 = getelementptr bfloat, ptr addrspace(1) %89, i64 %257, !dbg !74 + %259 = sext i32 %247 to i64, !dbg !74 + %260 = getelementptr bfloat, ptr addrspace(1) %89, i64 %259, !dbg !74 + %261 = sext i32 %248 to i64, !dbg !74 + %262 = getelementptr bfloat, ptr addrspace(1) %89, i64 %261, !dbg !74 + %263 = sext i32 %249 to i64, !dbg !74 + %264 = getelementptr bfloat, ptr addrspace(1) %89, i64 %263, !dbg !74 + %265 = sext i32 %250 to i64, !dbg !74 + %266 = getelementptr bfloat, ptr addrspace(1) %89, i64 %265, !dbg !74 + %267 = getelementptr bfloat, ptr addrspace(1) %252, i64 %130, !dbg !75 + %268 = getelementptr bfloat, ptr addrspace(1) %254, i64 %130, !dbg !75 + %269 = getelementptr bfloat, ptr addrspace(1) %256, i64 %130, !dbg !75 + %270 = getelementptr bfloat, ptr addrspace(1) %258, i64 %130, !dbg !75 + %271 = getelementptr bfloat, ptr addrspace(1) %260, i64 %130, !dbg !75 + %272 = getelementptr bfloat, ptr addrspace(1) %262, i64 %130, !dbg !75 + %273 = getelementptr bfloat, ptr addrspace(1) %264, i64 %130, !dbg !75 + %274 = getelementptr bfloat, ptr addrspace(1) %266, i64 %130, !dbg !75 + %275 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %267, i1 %139) #3, !dbg !76 + %276 = extractvalue { i32, i32, i32, i32 } %275, 0, !dbg !76 + %277 = extractvalue { i32, i32, i32, i32 } %275, 1, !dbg !76 + %278 = extractvalue { i32, i32, i32, i32 } %275, 2, !dbg !76 + %279 = extractvalue { i32, i32, i32, i32 } %275, 3, !dbg !76 + %280 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %268, i1 %140) #3, !dbg !76 + %281 = extractvalue { i32, i32, i32, i32 } %280, 0, !dbg !76 + %282 = extractvalue { i32, i32, i32, i32 } %280, 1, !dbg !76 + %283 = extractvalue { i32, i32, i32, i32 } %280, 2, !dbg !76 + %284 = extractvalue { i32, i32, i32, i32 } %280, 3, !dbg !76 + %285 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %269, i1 %141) #3, !dbg !76 + %286 = extractvalue { i32, i32, i32, i32 } %285, 0, !dbg !76 + %287 = extractvalue { i32, i32, i32, i32 } %285, 1, !dbg !76 + %288 = extractvalue { i32, i32, i32, i32 } %285, 2, !dbg !76 + %289 = extractvalue { i32, i32, i32, i32 } %285, 3, !dbg !76 + %290 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %270, i1 %142) #3, !dbg !76 + %291 = extractvalue { i32, i32, i32, i32 } %290, 0, !dbg !76 + %292 = extractvalue { i32, i32, i32, i32 } %290, 1, !dbg !76 + %293 = extractvalue { i32, i32, i32, i32 } %290, 2, !dbg !76 + %294 = extractvalue { i32, i32, i32, i32 } %290, 3, !dbg !76 + %295 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %271, i1 %143) #3, !dbg !76 + %296 = extractvalue { i32, i32, i32, i32 } %295, 0, !dbg !76 + %297 = extractvalue { i32, i32, i32, i32 } %295, 1, !dbg !76 + %298 = extractvalue { i32, i32, i32, i32 } %295, 2, !dbg !76 + %299 = extractvalue { i32, i32, i32, i32 } %295, 3, !dbg !76 + %300 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %272, i1 %144) #3, !dbg !76 + %301 = extractvalue { i32, i32, i32, i32 } %300, 0, !dbg !76 + %302 = extractvalue { i32, i32, i32, i32 } %300, 1, !dbg !76 + %303 = extractvalue { i32, i32, i32, i32 } %300, 2, !dbg !76 + %304 = extractvalue { i32, i32, i32, i32 } %300, 3, !dbg !76 + %305 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %273, i1 %145) #3, !dbg !76 + %306 = extractvalue { i32, i32, i32, i32 } %305, 0, !dbg !76 + %307 = extractvalue { i32, i32, i32, i32 } %305, 1, !dbg !76 + %308 = extractvalue { i32, i32, i32, i32 } %305, 2, !dbg !76 + %309 = extractvalue { i32, i32, i32, i32 } %305, 3, !dbg !76 + %310 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %274, i1 %146) #3, !dbg !76 + %311 = extractvalue { i32, i32, i32, i32 } %310, 0, !dbg !76 + %312 = extractvalue { i32, i32, i32, i32 } %310, 1, !dbg !76 + %313 = extractvalue { i32, i32, i32, i32 } %310, 2, !dbg !76 + %314 = extractvalue { i32, i32, i32, i32 } %310, 3, !dbg !76 + %315 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %195, !dbg !76 + %316 = insertelement <4 x i32> poison, i32 %276, i64 0, !dbg !76 + %317 = insertelement <4 x i32> %316, i32 %277, i64 1, !dbg !76 + %318 = insertelement <4 x i32> %317, i32 %278, i64 2, !dbg !76 + %319 = insertelement <4 x i32> %318, i32 %279, i64 3, !dbg !76 + store <4 x i32> %319, ptr addrspace(3) %315, align 16, !dbg !76 + %320 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %201, !dbg !76 + %321 = insertelement <4 x i32> poison, i32 %281, i64 0, !dbg !76 + %322 = insertelement <4 x i32> %321, i32 %282, i64 1, !dbg !76 + %323 = insertelement <4 x i32> %322, i32 %283, i64 2, !dbg !76 + %324 = insertelement <4 x i32> %323, i32 %284, i64 3, !dbg !76 + store <4 x i32> %324, ptr addrspace(3) %320, align 16, !dbg !76 + %325 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %207, !dbg !76 + %326 = insertelement <4 x i32> poison, i32 %286, i64 0, !dbg !76 + %327 = insertelement <4 x i32> %326, i32 %287, i64 1, !dbg !76 + %328 = insertelement <4 x i32> %327, i32 %288, i64 2, !dbg !76 + %329 = insertelement <4 x i32> %328, i32 %289, i64 3, !dbg !76 + store <4 x i32> %329, ptr addrspace(3) %325, align 16, !dbg !76 + %330 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %213, !dbg !76 + %331 = insertelement <4 x i32> poison, i32 %291, i64 0, !dbg !76 + %332 = insertelement <4 x i32> %331, i32 %292, i64 1, !dbg !76 + %333 = insertelement <4 x i32> %332, i32 %293, i64 2, !dbg !76 + %334 = insertelement <4 x i32> %333, i32 %294, i64 3, !dbg !76 + store <4 x i32> %334, ptr addrspace(3) %330, align 16, !dbg !76 + %335 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %219, !dbg !76 + %336 = insertelement <4 x i32> poison, i32 %296, i64 0, !dbg !76 + %337 = insertelement <4 x i32> %336, i32 %297, i64 1, !dbg !76 + %338 = insertelement <4 x i32> %337, i32 %298, i64 2, !dbg !76 + %339 = insertelement <4 x i32> %338, i32 %299, i64 3, !dbg !76 + store <4 x i32> %339, ptr addrspace(3) %335, align 16, !dbg !76 + %340 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %225, !dbg !76 + %341 = insertelement <4 x i32> poison, i32 %301, i64 0, !dbg !76 + %342 = insertelement <4 x i32> %341, i32 %302, i64 1, !dbg !76 + %343 = insertelement <4 x i32> %342, i32 %303, i64 2, !dbg !76 + %344 = insertelement <4 x i32> %343, i32 %304, i64 3, !dbg !76 + store <4 x i32> %344, ptr addrspace(3) %340, align 16, !dbg !76 + %345 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %231, !dbg !76 + %346 = insertelement <4 x i32> poison, i32 %306, i64 0, !dbg !76 + %347 = insertelement <4 x i32> %346, i32 %307, i64 1, !dbg !76 + %348 = insertelement <4 x i32> %347, i32 %308, i64 2, !dbg !76 + %349 = insertelement <4 x i32> %348, i32 %309, i64 3, !dbg !76 + store <4 x i32> %349, ptr addrspace(3) %345, align 16, !dbg !76 + %350 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %237, !dbg !76 + %351 = insertelement <4 x i32> poison, i32 %311, i64 0, !dbg !76 + %352 = insertelement <4 x i32> %351, i32 %312, i64 1, !dbg !76 + %353 = insertelement <4 x i32> %352, i32 %313, i64 2, !dbg !76 + %354 = insertelement <4 x i32> %353, i32 %314, i64 3, !dbg !76 + store <4 x i32> %354, ptr addrspace(3) %350, align 16, !dbg !76 + %355 = icmp slt i32 %102, %17, !dbg !77 + %356 = icmp slt i32 %103, %17, !dbg !77 + %357 = sext i32 %102 to i64, !dbg !78 + %358 = getelementptr float, ptr addrspace(1) %92, i64 %357, !dbg !78 + %359 = sext i32 %103 to i64, !dbg !78 + %360 = getelementptr float, ptr addrspace(1) %92, i64 %359, !dbg !78 + %361 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %358, i1 %355) #3, !dbg !79 + %362 = bitcast i32 %361 to float, !dbg !79 + %363 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %360, i1 %356) #3, !dbg !79 + %364 = bitcast i32 %363 to float, !dbg !79 + %365 = getelementptr float, ptr addrspace(1) %91, i64 %357, !dbg !80 + %366 = getelementptr float, ptr addrspace(1) %91, i64 %359, !dbg !80 + %367 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %365, i1 %355) #3, !dbg !81 + %368 = bitcast i32 %367 to float, !dbg !81 + %369 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %366, i1 %356) #3, !dbg !81 + %370 = bitcast i32 %369 to float, !dbg !81 + %371 = fcmp oeq float %368, 0xFFF0000000000000, !dbg !82 + %372 = fcmp oeq float %370, 0xFFF0000000000000, !dbg !82 + %373 = select i1 %371, float 0.000000e+00, float %368, !dbg !83 + %374 = select i1 %372, float 0.000000e+00, float %370, !dbg !83 + %375 = sext i32 %75 to i64, !dbg !84 + %376 = getelementptr i32, ptr addrspace(1) %9, i64 %375, !dbg !84 + %377 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %376) #3, !dbg !85 + %378 = shl i32 %377, 7, !dbg !86 + %379 = sext i32 %.decomposed to i64, !dbg !87 + %380 = getelementptr i32, ptr addrspace(1) %8, i64 %379, !dbg !87 + %381 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %380) #3, !dbg !88 + %382 = and i32 %50, 3, !dbg !89 + %383 = shl nuw nsw i32 %382, 1, !dbg !89 + %384 = or disjoint i32 %383, 1, !dbg !89 + %385 = insertelement <2 x i32> poison, i32 %383, i64 0, !dbg !89 + %386 = shufflevector <2 x i32> %385, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !89 + %387 = or disjoint <2 x i32> %386, , !dbg !89 + %388 = insertelement <4 x i32> poison, i32 %383, i64 0, !dbg !89 + %389 = shufflevector <4 x i32> %388, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !89 + %390 = or disjoint <4 x i32> %389, , !dbg !89 + %391 = insertelement <8 x i32> poison, i32 %383, i64 0, !dbg !89 + %392 = shufflevector <8 x i32> %391, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !89 + %393 = or disjoint <8 x i32> %392, , !dbg !89 + %394 = or disjoint i32 %378, %53, !dbg !90 + %395 = or disjoint i32 %378, %54, !dbg !90 + %396 = or disjoint i32 %378, %55, !dbg !90 + %397 = or disjoint i32 %378, %56, !dbg !90 + %398 = shl i32 %394, 10, !dbg !91 + %399 = shl i32 %395, 10, !dbg !91 + %400 = shl i32 %396, 10, !dbg !91 + %401 = shl i32 %397, 10, !dbg !91 + %402 = sext i32 %398 to i64, !dbg !93 + %403 = getelementptr bfloat, ptr addrspace(1) %47, i64 %402, !dbg !93 + %404 = sext i32 %399 to i64, !dbg !93 + %405 = getelementptr bfloat, ptr addrspace(1) %47, i64 %404, !dbg !93 + %406 = sext i32 %400 to i64, !dbg !93 + %407 = getelementptr bfloat, ptr addrspace(1) %47, i64 %406, !dbg !93 + %408 = sext i32 %401 to i64, !dbg !93 + %409 = getelementptr bfloat, ptr addrspace(1) %47, i64 %408, !dbg !93 + %410 = getelementptr bfloat, ptr addrspace(1) %403, i64 %130, !dbg !94 + %411 = getelementptr bfloat, ptr addrspace(1) %405, i64 %130, !dbg !94 + %412 = getelementptr bfloat, ptr addrspace(1) %407, i64 %130, !dbg !94 + %413 = getelementptr bfloat, ptr addrspace(1) %409, i64 %130, !dbg !94 + %414 = getelementptr bfloat, ptr addrspace(1) %48, i64 %402, !dbg !95 + %415 = getelementptr bfloat, ptr addrspace(1) %48, i64 %404, !dbg !95 + %416 = getelementptr bfloat, ptr addrspace(1) %48, i64 %406, !dbg !95 + %417 = getelementptr bfloat, ptr addrspace(1) %48, i64 %408, !dbg !95 + %418 = getelementptr bfloat, ptr addrspace(1) %414, i64 %130, !dbg !96 + %419 = getelementptr bfloat, ptr addrspace(1) %415, i64 %130, !dbg !96 + %420 = getelementptr bfloat, ptr addrspace(1) %416, i64 %130, !dbg !96 + %421 = getelementptr bfloat, ptr addrspace(1) %417, i64 %130, !dbg !96 + %422 = shl i32 %381, 1, !dbg !97 + %423 = add i32 %18, 63, !dbg !98 + %424 = sdiv i32 %423, 64, !dbg !99 + %425 = tail call i32 @llvm.smax.i32(i32 %424, i32 1), !dbg !100 + %426 = tail call i32 @llvm.smin.i32(i32 %422, i32 %425), !dbg !101 + %427 = icmp sgt i32 %422, 0, !dbg !102 + %428 = icmp slt i32 %394, %18, !dbg !103 + %429 = icmp slt i32 %395, %18, !dbg !103 + %430 = icmp slt i32 %396, %18, !dbg !103 + %431 = icmp slt i32 %397, %18, !dbg !103 + %432 = and i1 %427, %428, !dbg !102 + %433 = and i1 %427, %429, !dbg !102 + %434 = and i1 %427, %430, !dbg !102 + %435 = and i1 %427, %431, !dbg !102 + %436 = shl nuw nsw i32 %191, 10, !dbg !104 + %437 = or disjoint i32 %194, %436, !dbg !104 + %438 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %437, !dbg !104 + %439 = select i1 %432, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %438, ptr addrspace(1) %410, i32 %439) #3, !dbg !104 + %440 = or disjoint i32 %437, 2048, !dbg !104 + %441 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %440, !dbg !104 + %442 = select i1 %433, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %441, ptr addrspace(1) %411, i32 %442) #3, !dbg !104 + %443 = or disjoint i32 %437, 4096, !dbg !104 + %444 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %443, !dbg !104 + %445 = select i1 %434, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %444, ptr addrspace(1) %412, i32 %445) #3, !dbg !104 + %446 = or disjoint i32 %437, 6144, !dbg !104 + %447 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %446, !dbg !104 + %448 = select i1 %435, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %447, ptr addrspace(1) %413, i32 %448) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + %449 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %437, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %449, ptr addrspace(1) %418, i32 %439) #3, !dbg !104 + %450 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %440, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %450, ptr addrspace(1) %419, i32 %442) #3, !dbg !104 + %451 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %443, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %451, ptr addrspace(1) %420, i32 %445) #3, !dbg !104 + %452 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %446, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %452, ptr addrspace(1) %421, i32 %448) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + %453 = icmp sgt i32 %426, 1, !dbg !102 + %454 = getelementptr i8, ptr addrspace(1) %410, i64 131072, !dbg !105 + %455 = getelementptr i8, ptr addrspace(1) %411, i64 131072, !dbg !105 + %456 = getelementptr i8, ptr addrspace(1) %412, i64 131072, !dbg !105 + %457 = getelementptr i8, ptr addrspace(1) %413, i64 131072, !dbg !105 + %458 = getelementptr i8, ptr addrspace(1) %418, i64 131072, !dbg !106 + %459 = getelementptr i8, ptr addrspace(1) %419, i64 131072, !dbg !106 + %460 = getelementptr i8, ptr addrspace(1) %420, i64 131072, !dbg !106 + %461 = getelementptr i8, ptr addrspace(1) %421, i64 131072, !dbg !106 + %462 = or disjoint i32 %394, 64, !dbg !107 + %463 = or disjoint i32 %395, 64, !dbg !107 + %464 = or disjoint i32 %396, 64, !dbg !107 + %465 = or disjoint i32 %397, 64, !dbg !107 + %466 = icmp slt i32 %462, %18, !dbg !103 + %467 = icmp slt i32 %463, %18, !dbg !103 + %468 = icmp slt i32 %464, %18, !dbg !103 + %469 = icmp slt i32 %465, %18, !dbg !103 + %470 = and i1 %453, %466, !dbg !102 + %471 = and i1 %453, %467, !dbg !102 + %472 = and i1 %453, %468, !dbg !102 + %473 = and i1 %453, %469, !dbg !102 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !104 + %474 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %437, !dbg !104 + %475 = select i1 %470, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %474, ptr addrspace(1) %454, i32 %475) #3, !dbg !104 + %476 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %440, !dbg !104 + %477 = select i1 %471, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %476, ptr addrspace(1) %455, i32 %477) #3, !dbg !104 + %478 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %443, !dbg !104 + %479 = select i1 %472, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %478, ptr addrspace(1) %456, i32 %479) #3, !dbg !104 + %480 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %446, !dbg !104 + %481 = select i1 %473, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %480, ptr addrspace(1) %457, i32 %481) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + %482 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %437, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %482, ptr addrspace(1) %458, i32 %475) #3, !dbg !104 + %483 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %440, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %483, ptr addrspace(1) %459, i32 %477) #3, !dbg !104 + %484 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %443, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %484, ptr addrspace(1) %460, i32 %479) #3, !dbg !104 + %485 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %446, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %485, ptr addrspace(1) %461, i32 %481) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !108 + br i1 %427, label %.lr.ph, label %._crit_edge, !dbg !102 + +.lr.ph: ; preds = %67 + %486 = srem i32 %103, %17, !dbg !109 + %487 = sdiv i32 %486, 16, !dbg !110 + %488 = icmp slt i32 %486, 0, !dbg !111 + %489 = and i32 %486, 15, !dbg !112 + %.not844 = icmp ne i32 %489, 0, !dbg !112 + %narrow846 = and i1 %488, %.not844, !dbg !113 + %490 = sext i1 %narrow846 to i32, !dbg !113 + %491 = add nsw i32 %487, %490, !dbg !113 + %492 = srem i32 %102, %17, !dbg !109 + %493 = sdiv i32 %492, 16, !dbg !110 + %494 = icmp slt i32 %492, 0, !dbg !111 + %495 = and i32 %492, 15, !dbg !112 + %.not843 = icmp ne i32 %495, 0, !dbg !112 + %narrow845 = and i1 %494, %.not843, !dbg !113 + %496 = sext i1 %narrow845 to i32, !dbg !113 + %497 = add nsw i32 %493, %496, !dbg !113 + %498 = icmp sgt i32 %486, -1, !dbg !114 + %499 = icmp sgt i32 %492, -1, !dbg !114 + %500 = insertelement <2 x i32> poison, i32 %378, i64 0, !dbg !90 + %501 = shufflevector <2 x i32> %500, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !90 + %502 = shufflevector <8 x i32> %393, <8 x i32> poison, <2 x i32> , !dbg !90 + %503 = or disjoint <2 x i32> %501, %502, !dbg !90 + %504 = shufflevector <8 x i32> %393, <8 x i32> poison, <2 x i32> , !dbg !90 + %505 = or disjoint <2 x i32> %501, %504, !dbg !90 + %506 = shufflevector <8 x i32> %393, <8 x i32> poison, <2 x i32> , !dbg !90 + %507 = or disjoint <2 x i32> %501, %506, !dbg !90 + %508 = shufflevector <8 x i32> %393, <8 x i32> poison, <2 x i32> , !dbg !90 + %509 = or disjoint <2 x i32> %501, %508, !dbg !90 + %510 = shufflevector <4 x i32> %390, <4 x i32> poison, <2 x i32> , !dbg !90 + %511 = or disjoint <2 x i32> %501, %510, !dbg !90 + %512 = shufflevector <4 x i32> %390, <4 x i32> poison, <2 x i32> , !dbg !90 + %513 = or disjoint <2 x i32> %501, %512, !dbg !90 + %514 = or disjoint <2 x i32> %501, %387, !dbg !90 + %515 = insertelement <2 x i32> %385, i32 %384, i64 1, !dbg !90 + %516 = or disjoint <2 x i32> %501, %515, !dbg !90 + %517 = add nsw i32 %426, -2 + %518 = add nsw i32 %426, -1 + %smax = tail call i32 @llvm.smax.i32(i32 %426, i32 1), !dbg !102 + %519 = insertelement <2 x i1> poison, i1 %488, i64 0, !dbg !115 + %520 = shufflevector <2 x i1> %519, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !115 + %521 = insertelement <2 x i32> poison, i32 %486, i64 0, !dbg !116 + %522 = shufflevector <2 x i32> %521, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !116 + %523 = insertelement <2 x i1> poison, i1 %498, i64 0, !dbg !117 + %524 = shufflevector <2 x i1> %523, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !117 + %525 = insertelement <2 x i32> poison, i32 %491, i64 0, !dbg !118 + %526 = shufflevector <2 x i32> %525, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !118 + %527 = insertelement <2 x i32> poison, i32 %18, i64 0, !dbg !103 + %528 = shufflevector <2 x i32> %527, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !103 + %529 = insertelement <2 x float> poison, float %364, i64 0, !dbg !119 + %530 = shufflevector <2 x float> %529, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !119 + %531 = insertelement <2 x i1> poison, i1 %494, i64 0, !dbg !115 + %532 = shufflevector <2 x i1> %531, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !115 + %533 = insertelement <2 x i32> poison, i32 %492, i64 0, !dbg !116 + %534 = shufflevector <2 x i32> %533, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !116 + %535 = insertelement <2 x i1> poison, i1 %499, i64 0, !dbg !117 + %536 = shufflevector <2 x i1> %535, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !117 + %537 = insertelement <2 x i32> poison, i32 %497, i64 0, !dbg !118 + %538 = shufflevector <2 x i32> %537, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !118 + %539 = insertelement <2 x float> poison, float %362, i64 0, !dbg !119 + %540 = shufflevector <2 x float> %539, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !119 + br label %541, !dbg !102 + +541: ; preds = %.lr.ph, %__nv_exp2f.exit1609 + %542 = phi i32 [ 64, %.lr.ph ], [ %2390, %__nv_exp2f.exit1609 ] + %543 = phi i32 [ -1, %.lr.ph ], [ %622, %__nv_exp2f.exit1609 ] + %544 = phi i32 [ 1, %.lr.ph ], [ %2407, %__nv_exp2f.exit1609 ] + %.pn9331626 = phi ptr addrspace(1) [ %461, %.lr.ph ], [ %2400, %__nv_exp2f.exit1609 ] + %.pn9491625 = phi ptr addrspace(1) [ %460, %.lr.ph ], [ %2399, %__nv_exp2f.exit1609 ] + %.pn9651624 = phi ptr addrspace(1) [ %459, %.lr.ph ], [ %2398, %__nv_exp2f.exit1609 ] + %.pn9811623 = phi ptr addrspace(1) [ %458, %.lr.ph ], [ %2397, %__nv_exp2f.exit1609 ] + %.pn9111622 = phi i32 [ %465, %.lr.ph ], [ %2404, %__nv_exp2f.exit1609 ] + %.pn9131621 = phi i32 [ %464, %.lr.ph ], [ %2403, %__nv_exp2f.exit1609 ] + %.pn9151620 = phi i32 [ %463, %.lr.ph ], [ %2402, %__nv_exp2f.exit1609 ] + %.pn9171619 = phi i32 [ %462, %.lr.ph ], [ %2401, %__nv_exp2f.exit1609 ] + %.pn8611618 = phi ptr addrspace(1) [ %457, %.lr.ph ], [ %2396, %__nv_exp2f.exit1609 ] + %.pn8771617 = phi ptr addrspace(1) [ %456, %.lr.ph ], [ %2395, %__nv_exp2f.exit1609 ] + %.pn8931616 = phi ptr addrspace(1) [ %455, %.lr.ph ], [ %2394, %__nv_exp2f.exit1609 ] + %.pn9091615 = phi ptr addrspace(1) [ %454, %.lr.ph ], [ %2393, %__nv_exp2f.exit1609 ] + %545 = phi float [ 0.000000e+00, %.lr.ph ], [ %2297, %__nv_exp2f.exit1609 ] + %546 = phi float [ 0.000000e+00, %.lr.ph ], [ %2298, %__nv_exp2f.exit1609 ] + %547 = phi float [ 0.000000e+00, %.lr.ph ], [ %2299, %__nv_exp2f.exit1609 ] + %548 = phi float [ 0.000000e+00, %.lr.ph ], [ %2300, %__nv_exp2f.exit1609 ] + %549 = phi float [ 0.000000e+00, %.lr.ph ], [ %2301, %__nv_exp2f.exit1609 ] + %550 = phi float [ 0.000000e+00, %.lr.ph ], [ %2302, %__nv_exp2f.exit1609 ] + %551 = phi float [ 0.000000e+00, %.lr.ph ], [ %2303, %__nv_exp2f.exit1609 ] + %552 = phi float [ 0.000000e+00, %.lr.ph ], [ %2304, %__nv_exp2f.exit1609 ] + %553 = phi float [ 0.000000e+00, %.lr.ph ], [ %2305, %__nv_exp2f.exit1609 ] + %554 = phi float [ 0.000000e+00, %.lr.ph ], [ %2306, %__nv_exp2f.exit1609 ] + %555 = phi float [ 0.000000e+00, %.lr.ph ], [ %2307, %__nv_exp2f.exit1609 ] + %556 = phi float [ 0.000000e+00, %.lr.ph ], [ %2308, %__nv_exp2f.exit1609 ] + %557 = phi float [ 0.000000e+00, %.lr.ph ], [ %2309, %__nv_exp2f.exit1609 ] + %558 = phi float [ 0.000000e+00, %.lr.ph ], [ %2310, %__nv_exp2f.exit1609 ] + %559 = phi float [ 0.000000e+00, %.lr.ph ], [ %2311, %__nv_exp2f.exit1609 ] + %560 = phi float [ 0.000000e+00, %.lr.ph ], [ %2312, %__nv_exp2f.exit1609 ] + %561 = phi float [ 0.000000e+00, %.lr.ph ], [ %2313, %__nv_exp2f.exit1609 ] + %562 = phi float [ 0.000000e+00, %.lr.ph ], [ %2314, %__nv_exp2f.exit1609 ] + %563 = phi float [ 0.000000e+00, %.lr.ph ], [ %2315, %__nv_exp2f.exit1609 ] + %564 = phi float [ 0.000000e+00, %.lr.ph ], [ %2316, %__nv_exp2f.exit1609 ] + %565 = phi float [ 0.000000e+00, %.lr.ph ], [ %2317, %__nv_exp2f.exit1609 ] + %566 = phi float [ 0.000000e+00, %.lr.ph ], [ %2318, %__nv_exp2f.exit1609 ] + %567 = phi float [ 0.000000e+00, %.lr.ph ], [ %2319, %__nv_exp2f.exit1609 ] + %568 = phi float [ 0.000000e+00, %.lr.ph ], [ %2320, %__nv_exp2f.exit1609 ] + %569 = phi float [ 0.000000e+00, %.lr.ph ], [ %2321, %__nv_exp2f.exit1609 ] + %570 = phi float [ 0.000000e+00, %.lr.ph ], [ %2322, %__nv_exp2f.exit1609 ] + %571 = phi float [ 0.000000e+00, %.lr.ph ], [ %2323, %__nv_exp2f.exit1609 ] + %572 = phi float [ 0.000000e+00, %.lr.ph ], [ %2324, %__nv_exp2f.exit1609 ] + %573 = phi float [ 0.000000e+00, %.lr.ph ], [ %2325, %__nv_exp2f.exit1609 ] + %574 = phi float [ 0.000000e+00, %.lr.ph ], [ %2326, %__nv_exp2f.exit1609 ] + %575 = phi float [ 0.000000e+00, %.lr.ph ], [ %2327, %__nv_exp2f.exit1609 ] + %576 = phi float [ 0.000000e+00, %.lr.ph ], [ %2328, %__nv_exp2f.exit1609 ] + %577 = phi float [ 0.000000e+00, %.lr.ph ], [ %2329, %__nv_exp2f.exit1609 ] + %578 = phi float [ 0.000000e+00, %.lr.ph ], [ %2330, %__nv_exp2f.exit1609 ] + %579 = phi float [ 0.000000e+00, %.lr.ph ], [ %2331, %__nv_exp2f.exit1609 ] + %580 = phi float [ 0.000000e+00, %.lr.ph ], [ %2332, %__nv_exp2f.exit1609 ] + %581 = phi float [ 0.000000e+00, %.lr.ph ], [ %2333, %__nv_exp2f.exit1609 ] + %582 = phi float [ 0.000000e+00, %.lr.ph ], [ %2334, %__nv_exp2f.exit1609 ] + %583 = phi float [ 0.000000e+00, %.lr.ph ], [ %2335, %__nv_exp2f.exit1609 ] + %584 = phi float [ 0.000000e+00, %.lr.ph ], [ %2336, %__nv_exp2f.exit1609 ] + %585 = phi float [ 0.000000e+00, %.lr.ph ], [ %2337, %__nv_exp2f.exit1609 ] + %586 = phi float [ 0.000000e+00, %.lr.ph ], [ %2338, %__nv_exp2f.exit1609 ] + %587 = phi float [ 0.000000e+00, %.lr.ph ], [ %2339, %__nv_exp2f.exit1609 ] + %588 = phi float [ 0.000000e+00, %.lr.ph ], [ %2340, %__nv_exp2f.exit1609 ] + %589 = phi float [ 0.000000e+00, %.lr.ph ], [ %2341, %__nv_exp2f.exit1609 ] + %590 = phi float [ 0.000000e+00, %.lr.ph ], [ %2342, %__nv_exp2f.exit1609 ] + %591 = phi float [ 0.000000e+00, %.lr.ph ], [ %2343, %__nv_exp2f.exit1609 ] + %592 = phi float [ 0.000000e+00, %.lr.ph ], [ %2344, %__nv_exp2f.exit1609 ] + %593 = phi float [ 0.000000e+00, %.lr.ph ], [ %2345, %__nv_exp2f.exit1609 ] + %594 = phi float [ 0.000000e+00, %.lr.ph ], [ %2346, %__nv_exp2f.exit1609 ] + %595 = phi float [ 0.000000e+00, %.lr.ph ], [ %2347, %__nv_exp2f.exit1609 ] + %596 = phi float [ 0.000000e+00, %.lr.ph ], [ %2348, %__nv_exp2f.exit1609 ] + %597 = phi float [ 0.000000e+00, %.lr.ph ], [ %2349, %__nv_exp2f.exit1609 ] + %598 = phi float [ 0.000000e+00, %.lr.ph ], [ %2350, %__nv_exp2f.exit1609 ] + %599 = phi float [ 0.000000e+00, %.lr.ph ], [ %2351, %__nv_exp2f.exit1609 ] + %600 = phi float [ 0.000000e+00, %.lr.ph ], [ %2352, %__nv_exp2f.exit1609 ] + %601 = phi float [ 0.000000e+00, %.lr.ph ], [ %2353, %__nv_exp2f.exit1609 ] + %602 = phi float [ 0.000000e+00, %.lr.ph ], [ %2354, %__nv_exp2f.exit1609 ] + %603 = phi float [ 0.000000e+00, %.lr.ph ], [ %2355, %__nv_exp2f.exit1609 ] + %604 = phi float [ 0.000000e+00, %.lr.ph ], [ %2356, %__nv_exp2f.exit1609 ] + %605 = phi float [ 0.000000e+00, %.lr.ph ], [ %2357, %__nv_exp2f.exit1609 ] + %606 = phi float [ 0.000000e+00, %.lr.ph ], [ %2358, %__nv_exp2f.exit1609 ] + %607 = phi float [ 0.000000e+00, %.lr.ph ], [ %2359, %__nv_exp2f.exit1609 ] + %608 = phi float [ 0.000000e+00, %.lr.ph ], [ %2360, %__nv_exp2f.exit1609 ] + %609 = phi i32 [ 0, %.lr.ph ], [ %2371, %__nv_exp2f.exit1609 ] + %610 = phi <2 x i32> [ %503, %.lr.ph ], [ %2370, %__nv_exp2f.exit1609 ] + %611 = phi <2 x i32> [ %505, %.lr.ph ], [ %2369, %__nv_exp2f.exit1609 ] + %612 = phi <2 x i32> [ %507, %.lr.ph ], [ %2368, %__nv_exp2f.exit1609 ] + %613 = phi <2 x i32> [ %509, %.lr.ph ], [ %2367, %__nv_exp2f.exit1609 ] + %614 = phi <2 x i32> [ %511, %.lr.ph ], [ %2366, %__nv_exp2f.exit1609 ] + %615 = phi <2 x i32> [ %513, %.lr.ph ], [ %2365, %__nv_exp2f.exit1609 ] + %616 = phi <2 x i32> [ %514, %.lr.ph ], [ %2364, %__nv_exp2f.exit1609 ] + %617 = phi <2 x i32> [ %516, %.lr.ph ], [ %2363, %__nv_exp2f.exit1609 ] + %618 = icmp slt i32 %609, %517, !dbg !102 + %619 = icmp slt i32 %609, %518, !dbg !102 + %620 = add i32 %543, 1, !dbg !102 + %621 = icmp sgt i32 %620, 2, !dbg !102 + %622 = select i1 %621, i32 0, i32 %620, !dbg !102 + %623 = icmp slt <2 x i32> %617, %528, !dbg !103 + %624 = icmp slt <2 x i32> %616, %528, !dbg !103 + %625 = icmp slt <2 x i32> %615, %528, !dbg !103 + %626 = icmp slt <2 x i32> %614, %528, !dbg !103 + %627 = icmp slt <2 x i32> %613, %528, !dbg !103 + %628 = icmp slt <2 x i32> %612, %528, !dbg !103 + %629 = icmp slt <2 x i32> %611, %528, !dbg !103 + %630 = icmp slt <2 x i32> %610, %528, !dbg !103 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !104 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !104 + %631 = shl i32 %622, 13, !dbg !104 + %632 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %631, !dbg !104 + %633 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %51, i32 0, i32 31), !dbg !108 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !108 + %634 = shl i32 %633, 11, !dbg !108 + %635 = and i32 %634, 8192, !dbg !108 + %636 = add i32 %635, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %637 = lshr exact i32 %636, 4, !dbg !108 + %638 = and i32 %637, 16383, !dbg !108 + %639 = zext nneg i32 %638 to i64, !dbg !108 + %640 = or disjoint i64 %639, 4611686293372403712, !dbg !108 + %641 = ptrtoint ptr addrspace(3) %632 to i32, !dbg !108 + %642 = lshr exact i32 %641, 4, !dbg !108 + %643 = and i32 %642, 16383, !dbg !108 + %644 = zext nneg i32 %643 to i64, !dbg !108 + %645 = or disjoint i64 %644, 4611686293338849280, !dbg !108 + %646 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %640, i64 %645) #3, !dbg !108 + %647 = or disjoint i32 %635, 32, !dbg !108 + %648 = add i32 %647, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %649 = lshr exact i32 %648, 4, !dbg !108 + %650 = and i32 %649, 16383, !dbg !108 + %651 = zext nneg i32 %650 to i64, !dbg !108 + %652 = or disjoint i64 %651, 4611686293372403712, !dbg !108 + %653 = add i32 %641, 32, !dbg !108 + %654 = lshr exact i32 %653, 4, !dbg !108 + %655 = and i32 %654, 16383, !dbg !108 + %656 = zext nneg i32 %655 to i64, !dbg !108 + %657 = or disjoint i64 %656, 4611686293338849280, !dbg !108 + %658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 0, !dbg !108 + %659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 1, !dbg !108 + %660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 2, !dbg !108 + %661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 3, !dbg !108 + %662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 4, !dbg !108 + %663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 5, !dbg !108 + %664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 6, !dbg !108 + %665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 7, !dbg !108 + %666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 8, !dbg !108 + %667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 9, !dbg !108 + %668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 10, !dbg !108 + %669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 11, !dbg !108 + %670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 12, !dbg !108 + %671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 13, !dbg !108 + %672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 14, !dbg !108 + %673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 15, !dbg !108 + %674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 16, !dbg !108 + %675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 17, !dbg !108 + %676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 18, !dbg !108 + %677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 19, !dbg !108 + %678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 20, !dbg !108 + %679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 21, !dbg !108 + %680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 22, !dbg !108 + %681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 23, !dbg !108 + %682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 24, !dbg !108 + %683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 25, !dbg !108 + %684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 26, !dbg !108 + %685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 27, !dbg !108 + %686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 28, !dbg !108 + %687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 29, !dbg !108 + %688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 30, !dbg !108 + %689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %646, 31, !dbg !108 + %690 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %658, float %659, float %660, float %661, float %662, float %663, float %664, float %665, float %666, float %667, float %668, float %669, float %670, float %671, float %672, float %673, float %674, float %675, float %676, float %677, float %678, float %679, float %680, float %681, float %682, float %683, float %684, float %685, float %686, float %687, float %688, float %689, i64 %652, i64 %657, i1 true) #3, !dbg !108 + %691 = or disjoint i32 %635, 64, !dbg !108 + %692 = add i32 %691, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %693 = lshr exact i32 %692, 4, !dbg !108 + %694 = and i32 %693, 16383, !dbg !108 + %695 = zext nneg i32 %694 to i64, !dbg !108 + %696 = or disjoint i64 %695, 4611686293372403712, !dbg !108 + %697 = add i32 %641, 64, !dbg !108 + %698 = lshr exact i32 %697, 4, !dbg !108 + %699 = and i32 %698, 16383, !dbg !108 + %700 = zext nneg i32 %699 to i64, !dbg !108 + %701 = or disjoint i64 %700, 4611686293338849280, !dbg !108 + %702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 0, !dbg !108 + %703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 1, !dbg !108 + %704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 2, !dbg !108 + %705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 3, !dbg !108 + %706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 4, !dbg !108 + %707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 5, !dbg !108 + %708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 6, !dbg !108 + %709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 7, !dbg !108 + %710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 8, !dbg !108 + %711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 9, !dbg !108 + %712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 10, !dbg !108 + %713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 11, !dbg !108 + %714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 12, !dbg !108 + %715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 13, !dbg !108 + %716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 14, !dbg !108 + %717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 15, !dbg !108 + %718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 16, !dbg !108 + %719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 17, !dbg !108 + %720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 18, !dbg !108 + %721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 19, !dbg !108 + %722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 20, !dbg !108 + %723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 21, !dbg !108 + %724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 22, !dbg !108 + %725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 23, !dbg !108 + %726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 24, !dbg !108 + %727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 25, !dbg !108 + %728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 26, !dbg !108 + %729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 27, !dbg !108 + %730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 28, !dbg !108 + %731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 29, !dbg !108 + %732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 30, !dbg !108 + %733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %690, 31, !dbg !108 + %734 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %702, float %703, float %704, float %705, float %706, float %707, float %708, float %709, float %710, float %711, float %712, float %713, float %714, float %715, float %716, float %717, float %718, float %719, float %720, float %721, float %722, float %723, float %724, float %725, float %726, float %727, float %728, float %729, float %730, float %731, float %732, float %733, i64 %696, i64 %701, i1 true) #3, !dbg !108 + %735 = or disjoint i32 %635, 96, !dbg !108 + %736 = add i32 %735, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %737 = lshr exact i32 %736, 4, !dbg !108 + %738 = and i32 %737, 16383, !dbg !108 + %739 = zext nneg i32 %738 to i64, !dbg !108 + %740 = or disjoint i64 %739, 4611686293372403712, !dbg !108 + %741 = add i32 %641, 96, !dbg !108 + %742 = lshr exact i32 %741, 4, !dbg !108 + %743 = and i32 %742, 16383, !dbg !108 + %744 = zext nneg i32 %743 to i64, !dbg !108 + %745 = or disjoint i64 %744, 4611686293338849280, !dbg !108 + %746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 0, !dbg !108 + %747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 1, !dbg !108 + %748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 2, !dbg !108 + %749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 3, !dbg !108 + %750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 4, !dbg !108 + %751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 5, !dbg !108 + %752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 6, !dbg !108 + %753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 7, !dbg !108 + %754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 8, !dbg !108 + %755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 9, !dbg !108 + %756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 10, !dbg !108 + %757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 11, !dbg !108 + %758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 12, !dbg !108 + %759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 13, !dbg !108 + %760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 14, !dbg !108 + %761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 15, !dbg !108 + %762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 16, !dbg !108 + %763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 17, !dbg !108 + %764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 18, !dbg !108 + %765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 19, !dbg !108 + %766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 20, !dbg !108 + %767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 21, !dbg !108 + %768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 22, !dbg !108 + %769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 23, !dbg !108 + %770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 24, !dbg !108 + %771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 25, !dbg !108 + %772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 26, !dbg !108 + %773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 27, !dbg !108 + %774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 28, !dbg !108 + %775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 29, !dbg !108 + %776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 30, !dbg !108 + %777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %734, 31, !dbg !108 + %778 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %746, float %747, float %748, float %749, float %750, float %751, float %752, float %753, float %754, float %755, float %756, float %757, float %758, float %759, float %760, float %761, float %762, float %763, float %764, float %765, float %766, float %767, float %768, float %769, float %770, float %771, float %772, float %773, float %774, float %775, float %776, float %777, i64 %740, i64 %745, i1 true) #3, !dbg !108 + %779 = or disjoint i32 %635, 16384, !dbg !108 + %780 = add i32 %779, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %781 = lshr exact i32 %780, 4, !dbg !108 + %782 = and i32 %781, 16383, !dbg !108 + %783 = zext nneg i32 %782 to i64, !dbg !108 + %784 = or disjoint i64 %783, 4611686293372403712, !dbg !108 + %785 = add i32 %641, 8192, !dbg !108 + %786 = lshr exact i32 %785, 4, !dbg !108 + %787 = and i32 %786, 16383, !dbg !108 + %788 = zext nneg i32 %787 to i64, !dbg !108 + %789 = or disjoint i64 %788, 4611686293338849280, !dbg !108 + %790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 0, !dbg !108 + %791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 1, !dbg !108 + %792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 2, !dbg !108 + %793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 3, !dbg !108 + %794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 4, !dbg !108 + %795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 5, !dbg !108 + %796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 6, !dbg !108 + %797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 7, !dbg !108 + %798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 8, !dbg !108 + %799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 9, !dbg !108 + %800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 10, !dbg !108 + %801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 11, !dbg !108 + %802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 12, !dbg !108 + %803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 13, !dbg !108 + %804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 14, !dbg !108 + %805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 15, !dbg !108 + %806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 16, !dbg !108 + %807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 17, !dbg !108 + %808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 18, !dbg !108 + %809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 19, !dbg !108 + %810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 20, !dbg !108 + %811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 21, !dbg !108 + %812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 22, !dbg !108 + %813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 23, !dbg !108 + %814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 24, !dbg !108 + %815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 25, !dbg !108 + %816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 26, !dbg !108 + %817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 27, !dbg !108 + %818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 28, !dbg !108 + %819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 29, !dbg !108 + %820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 30, !dbg !108 + %821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %778, 31, !dbg !108 + %822 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %790, float %791, float %792, float %793, float %794, float %795, float %796, float %797, float %798, float %799, float %800, float %801, float %802, float %803, float %804, float %805, float %806, float %807, float %808, float %809, float %810, float %811, float %812, float %813, float %814, float %815, float %816, float %817, float %818, float %819, float %820, float %821, i64 %784, i64 %789, i1 true) #3, !dbg !108 + %823 = or disjoint i32 %635, 16416, !dbg !108 + %824 = add i32 %823, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %825 = lshr exact i32 %824, 4, !dbg !108 + %826 = and i32 %825, 16383, !dbg !108 + %827 = zext nneg i32 %826 to i64, !dbg !108 + %828 = or disjoint i64 %827, 4611686293372403712, !dbg !108 + %829 = add i32 %641, 8224, !dbg !108 + %830 = lshr exact i32 %829, 4, !dbg !108 + %831 = and i32 %830, 16383, !dbg !108 + %832 = zext nneg i32 %831 to i64, !dbg !108 + %833 = or disjoint i64 %832, 4611686293338849280, !dbg !108 + %834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 0, !dbg !108 + %835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 1, !dbg !108 + %836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 2, !dbg !108 + %837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 3, !dbg !108 + %838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 4, !dbg !108 + %839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 5, !dbg !108 + %840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 6, !dbg !108 + %841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 7, !dbg !108 + %842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 8, !dbg !108 + %843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 9, !dbg !108 + %844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 10, !dbg !108 + %845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 11, !dbg !108 + %846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 12, !dbg !108 + %847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 13, !dbg !108 + %848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 14, !dbg !108 + %849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 15, !dbg !108 + %850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 16, !dbg !108 + %851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 17, !dbg !108 + %852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 18, !dbg !108 + %853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 19, !dbg !108 + %854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 20, !dbg !108 + %855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 21, !dbg !108 + %856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 22, !dbg !108 + %857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 23, !dbg !108 + %858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 24, !dbg !108 + %859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 25, !dbg !108 + %860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 26, !dbg !108 + %861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 27, !dbg !108 + %862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 28, !dbg !108 + %863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 29, !dbg !108 + %864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 30, !dbg !108 + %865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %822, 31, !dbg !108 + %866 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %834, float %835, float %836, float %837, float %838, float %839, float %840, float %841, float %842, float %843, float %844, float %845, float %846, float %847, float %848, float %849, float %850, float %851, float %852, float %853, float %854, float %855, float %856, float %857, float %858, float %859, float %860, float %861, float %862, float %863, float %864, float %865, i64 %828, i64 %833, i1 true) #3, !dbg !108 + %867 = or disjoint i32 %635, 16448, !dbg !108 + %868 = add i32 %867, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %869 = lshr exact i32 %868, 4, !dbg !108 + %870 = and i32 %869, 16383, !dbg !108 + %871 = zext nneg i32 %870 to i64, !dbg !108 + %872 = or disjoint i64 %871, 4611686293372403712, !dbg !108 + %873 = add i32 %641, 8256, !dbg !108 + %874 = lshr exact i32 %873, 4, !dbg !108 + %875 = and i32 %874, 16383, !dbg !108 + %876 = zext nneg i32 %875 to i64, !dbg !108 + %877 = or disjoint i64 %876, 4611686293338849280, !dbg !108 + %878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 0, !dbg !108 + %879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 1, !dbg !108 + %880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 2, !dbg !108 + %881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 3, !dbg !108 + %882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 4, !dbg !108 + %883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 5, !dbg !108 + %884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 6, !dbg !108 + %885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 7, !dbg !108 + %886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 8, !dbg !108 + %887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 9, !dbg !108 + %888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 10, !dbg !108 + %889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 11, !dbg !108 + %890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 12, !dbg !108 + %891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 13, !dbg !108 + %892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 14, !dbg !108 + %893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 15, !dbg !108 + %894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 16, !dbg !108 + %895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 17, !dbg !108 + %896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 18, !dbg !108 + %897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 19, !dbg !108 + %898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 20, !dbg !108 + %899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 21, !dbg !108 + %900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 22, !dbg !108 + %901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 23, !dbg !108 + %902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 24, !dbg !108 + %903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 25, !dbg !108 + %904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 26, !dbg !108 + %905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 27, !dbg !108 + %906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 28, !dbg !108 + %907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 29, !dbg !108 + %908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 30, !dbg !108 + %909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %866, 31, !dbg !108 + %910 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %878, float %879, float %880, float %881, float %882, float %883, float %884, float %885, float %886, float %887, float %888, float %889, float %890, float %891, float %892, float %893, float %894, float %895, float %896, float %897, float %898, float %899, float %900, float %901, float %902, float %903, float %904, float %905, float %906, float %907, float %908, float %909, i64 %872, i64 %877, i1 true) #3, !dbg !108 + %911 = or disjoint i32 %635, 16480, !dbg !108 + %912 = add i32 %911, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !108 + %913 = lshr exact i32 %912, 4, !dbg !108 + %914 = and i32 %913, 16383, !dbg !108 + %915 = zext nneg i32 %914 to i64, !dbg !108 + %916 = or disjoint i64 %915, 4611686293372403712, !dbg !108 + %917 = add i32 %641, 8288, !dbg !108 + %918 = lshr exact i32 %917, 4, !dbg !108 + %919 = and i32 %918, 16383, !dbg !108 + %920 = zext nneg i32 %919 to i64, !dbg !108 + %921 = or disjoint i64 %920, 4611686293338849280, !dbg !108 + %922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 0, !dbg !108 + %923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 1, !dbg !108 + %924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 2, !dbg !108 + %925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 3, !dbg !108 + %926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 4, !dbg !108 + %927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 5, !dbg !108 + %928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 6, !dbg !108 + %929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 7, !dbg !108 + %930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 8, !dbg !108 + %931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 9, !dbg !108 + %932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 10, !dbg !108 + %933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 11, !dbg !108 + %934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 12, !dbg !108 + %935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 13, !dbg !108 + %936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 14, !dbg !108 + %937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 15, !dbg !108 + %938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 16, !dbg !108 + %939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 17, !dbg !108 + %940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 18, !dbg !108 + %941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 19, !dbg !108 + %942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 20, !dbg !108 + %943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 21, !dbg !108 + %944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 22, !dbg !108 + %945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 23, !dbg !108 + %946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 24, !dbg !108 + %947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 25, !dbg !108 + %948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 26, !dbg !108 + %949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 27, !dbg !108 + %950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 28, !dbg !108 + %951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 29, !dbg !108 + %952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 30, !dbg !108 + %953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %910, 31, !dbg !108 + %954 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %922, float %923, float %924, float %925, float %926, float %927, float %928, float %929, float %930, float %931, float %932, float %933, float %934, float %935, float %936, float %937, float %938, float %939, float %940, float %941, float %942, float %943, float %944, float %945, float %946, float %947, float %948, float %949, float %950, float %951, float %952, float %953, i64 %916, i64 %921, i1 true) #3, !dbg !108 + %955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 0, !dbg !108 + %956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 1, !dbg !108 + %957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 2, !dbg !108 + %958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 3, !dbg !108 + %959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 4, !dbg !108 + %960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 5, !dbg !108 + %961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 6, !dbg !108 + %962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 7, !dbg !108 + %963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 8, !dbg !108 + %964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 9, !dbg !108 + %965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 10, !dbg !108 + %966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 11, !dbg !108 + %967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 12, !dbg !108 + %968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 13, !dbg !108 + %969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 14, !dbg !108 + %970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 15, !dbg !108 + %971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 16, !dbg !108 + %972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 17, !dbg !108 + %973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 18, !dbg !108 + %974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 19, !dbg !108 + %975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 20, !dbg !108 + %976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 21, !dbg !108 + %977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 22, !dbg !108 + %978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 23, !dbg !108 + %979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 24, !dbg !108 + %980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 25, !dbg !108 + %981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 26, !dbg !108 + %982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 27, !dbg !108 + %983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 28, !dbg !108 + %984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 29, !dbg !108 + %985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 30, !dbg !108 + %986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %954, 31, !dbg !108 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !108 + %987 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %955, float %956, float %957, float %958, float %959, float %960, float %961, float %962, float %963, float %964, float %965, float %966, float %967, float %968, float %969, float %970, float %971, float %972, float %973, float %974, float %975, float %976, float %977, float %978, float %979, float %980, float %981, float %982, float %983, float %984, float %985, float %986, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %632, i32 0, i32 0) #3, !dbg !108 + %988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 0, !dbg !108 + %989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 1, !dbg !108 + %990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 2, !dbg !108 + %991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 3, !dbg !108 + %992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 4, !dbg !108 + %993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 5, !dbg !108 + %994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 6, !dbg !108 + %995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 7, !dbg !108 + %996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 8, !dbg !108 + %997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 9, !dbg !108 + %998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 10, !dbg !108 + %999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 11, !dbg !108 + %1000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 12, !dbg !108 + %1001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 13, !dbg !108 + %1002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 14, !dbg !108 + %1003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 15, !dbg !108 + %1004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 16, !dbg !108 + %1005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 17, !dbg !108 + %1006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 18, !dbg !108 + %1007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 19, !dbg !108 + %1008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 20, !dbg !108 + %1009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 21, !dbg !108 + %1010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 22, !dbg !108 + %1011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 23, !dbg !108 + %1012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 24, !dbg !108 + %1013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 25, !dbg !108 + %1014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 26, !dbg !108 + %1015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 27, !dbg !108 + %1016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 28, !dbg !108 + %1017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 29, !dbg !108 + %1018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 30, !dbg !108 + %1019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %987, 31, !dbg !108 + %1020 = fmul float %988, 0x3FB6A09E60000000, !dbg !120 + %1021 = fmul float %989, 0x3FB6A09E60000000, !dbg !120 + %1022 = fmul float %990, 0x3FB6A09E60000000, !dbg !120 + %1023 = fmul float %991, 0x3FB6A09E60000000, !dbg !120 + %1024 = fmul float %992, 0x3FB6A09E60000000, !dbg !120 + %1025 = fmul float %993, 0x3FB6A09E60000000, !dbg !120 + %1026 = fmul float %994, 0x3FB6A09E60000000, !dbg !120 + %1027 = fmul float %995, 0x3FB6A09E60000000, !dbg !120 + %1028 = fmul float %996, 0x3FB6A09E60000000, !dbg !120 + %1029 = fmul float %997, 0x3FB6A09E60000000, !dbg !120 + %1030 = fmul float %998, 0x3FB6A09E60000000, !dbg !120 + %1031 = fmul float %999, 0x3FB6A09E60000000, !dbg !120 + %1032 = fmul float %1000, 0x3FB6A09E60000000, !dbg !120 + %1033 = fmul float %1001, 0x3FB6A09E60000000, !dbg !120 + %1034 = fmul float %1002, 0x3FB6A09E60000000, !dbg !120 + %1035 = fmul float %1003, 0x3FB6A09E60000000, !dbg !120 + %1036 = fmul float %1004, 0x3FB6A09E60000000, !dbg !120 + %1037 = fmul float %1005, 0x3FB6A09E60000000, !dbg !120 + %1038 = fmul float %1006, 0x3FB6A09E60000000, !dbg !120 + %1039 = fmul float %1007, 0x3FB6A09E60000000, !dbg !120 + %1040 = fmul float %1008, 0x3FB6A09E60000000, !dbg !120 + %1041 = fmul float %1009, 0x3FB6A09E60000000, !dbg !120 + %1042 = fmul float %1010, 0x3FB6A09E60000000, !dbg !120 + %1043 = fmul float %1011, 0x3FB6A09E60000000, !dbg !120 + %1044 = fmul float %1012, 0x3FB6A09E60000000, !dbg !120 + %1045 = fmul float %1013, 0x3FB6A09E60000000, !dbg !120 + %1046 = fmul float %1014, 0x3FB6A09E60000000, !dbg !120 + %1047 = fmul float %1015, 0x3FB6A09E60000000, !dbg !120 + %1048 = fmul float %1016, 0x3FB6A09E60000000, !dbg !120 + %1049 = fmul float %1017, 0x3FB6A09E60000000, !dbg !120 + %1050 = fmul float %1018, 0x3FB6A09E60000000, !dbg !120 + %1051 = fmul float %1019, 0x3FB6A09E60000000, !dbg !120 + %1052 = srem <2 x i32> %617, %528, !dbg !109 + %1053 = icmp sle <2 x i32> %1052, %522, !dbg !116 + %1054 = and <2 x i1> %520, %1053, !dbg !115 + %1055 = icmp slt <2 x i32> %1052, zeroinitializer, !dbg !121 + %1056 = and <2 x i1> %524, %1055, !dbg !117 + %1057 = or <2 x i32> %1052, %522, !dbg !122 + %1058 = icmp sgt <2 x i32> %1057, splat (i32 -1), !dbg !122 + %1059 = and <2 x i32> %1052, splat (i32 15), !dbg !123 + %1060 = icmp ne <2 x i32> %1059, zeroinitializer, !dbg !123 + %1061 = sdiv <2 x i32> %1052, splat (i32 16), !dbg !124 + %1062 = and <2 x i1> %1055, %1060, !dbg !125 + %1063 = sext <2 x i1> %1062 to <2 x i32>, !dbg !125 + %1064 = add nsw <2 x i32> %1061, %1063, !dbg !125 + %1065 = icmp eq <2 x i32> %526, %1064, !dbg !118 + %1066 = and <2 x i1> %1058, %1065, !dbg !126 + %1067 = or <2 x i1> %1056, %1066, !dbg !127 + %1068 = or <2 x i1> %1054, %1067, !dbg !128 + %1069 = icmp sle <2 x i32> %1052, %534, !dbg !116 + %1070 = and <2 x i1> %532, %1069, !dbg !115 + %1071 = and <2 x i1> %536, %1055, !dbg !117 + %1072 = or <2 x i32> %1052, %534, !dbg !122 + %1073 = icmp sgt <2 x i32> %1072, splat (i32 -1), !dbg !122 + %1074 = icmp eq <2 x i32> %538, %1064, !dbg !118 + %1075 = and <2 x i1> %1073, %1074, !dbg !126 + %1076 = or <2 x i1> %1071, %1075, !dbg !127 + %1077 = or <2 x i1> %1070, %1076, !dbg !128 + %1078 = select <2 x i1> %1077, <2 x i1> %623, <2 x i1> zeroinitializer, !dbg !129 + %1079 = select <2 x i1> %1068, <2 x i1> %623, <2 x i1> zeroinitializer, !dbg !129 + %1080 = srem <2 x i32> %616, %528, !dbg !109 + %1081 = icmp sle <2 x i32> %1080, %522, !dbg !116 + %1082 = and <2 x i1> %520, %1081, !dbg !115 + %1083 = icmp slt <2 x i32> %1080, zeroinitializer, !dbg !121 + %1084 = and <2 x i1> %524, %1083, !dbg !117 + %1085 = or <2 x i32> %1080, %522, !dbg !122 + %1086 = icmp sgt <2 x i32> %1085, splat (i32 -1), !dbg !122 + %1087 = and <2 x i32> %1080, splat (i32 15), !dbg !123 + %1088 = icmp ne <2 x i32> %1087, zeroinitializer, !dbg !123 + %1089 = sdiv <2 x i32> %1080, splat (i32 16), !dbg !124 + %1090 = and <2 x i1> %1083, %1088, !dbg !125 + %1091 = sext <2 x i1> %1090 to <2 x i32>, !dbg !125 + %1092 = add nsw <2 x i32> %1089, %1091, !dbg !125 + %1093 = icmp eq <2 x i32> %526, %1092, !dbg !118 + %1094 = and <2 x i1> %1086, %1093, !dbg !126 + %1095 = or <2 x i1> %1084, %1094, !dbg !127 + %1096 = or <2 x i1> %1082, %1095, !dbg !128 + %1097 = icmp sle <2 x i32> %1080, %534, !dbg !116 + %1098 = and <2 x i1> %532, %1097, !dbg !115 + %1099 = and <2 x i1> %536, %1083, !dbg !117 + %1100 = or <2 x i32> %1080, %534, !dbg !122 + %1101 = icmp sgt <2 x i32> %1100, splat (i32 -1), !dbg !122 + %1102 = icmp eq <2 x i32> %538, %1092, !dbg !118 + %1103 = and <2 x i1> %1101, %1102, !dbg !126 + %1104 = or <2 x i1> %1099, %1103, !dbg !127 + %1105 = or <2 x i1> %1098, %1104, !dbg !128 + %1106 = select <2 x i1> %1105, <2 x i1> %624, <2 x i1> zeroinitializer, !dbg !129 + %1107 = select <2 x i1> %1096, <2 x i1> %624, <2 x i1> zeroinitializer, !dbg !129 + %1108 = srem <2 x i32> %615, %528, !dbg !109 + %1109 = icmp sle <2 x i32> %1108, %522, !dbg !116 + %1110 = and <2 x i1> %520, %1109, !dbg !115 + %1111 = icmp slt <2 x i32> %1108, zeroinitializer, !dbg !121 + %1112 = and <2 x i1> %524, %1111, !dbg !117 + %1113 = or <2 x i32> %1108, %522, !dbg !122 + %1114 = icmp sgt <2 x i32> %1113, splat (i32 -1), !dbg !122 + %1115 = and <2 x i32> %1108, splat (i32 15), !dbg !123 + %1116 = icmp ne <2 x i32> %1115, zeroinitializer, !dbg !123 + %1117 = sdiv <2 x i32> %1108, splat (i32 16), !dbg !124 + %1118 = and <2 x i1> %1111, %1116, !dbg !125 + %1119 = sext <2 x i1> %1118 to <2 x i32>, !dbg !125 + %1120 = add nsw <2 x i32> %1117, %1119, !dbg !125 + %1121 = icmp eq <2 x i32> %526, %1120, !dbg !118 + %1122 = and <2 x i1> %1114, %1121, !dbg !126 + %1123 = or <2 x i1> %1112, %1122, !dbg !127 + %1124 = or <2 x i1> %1110, %1123, !dbg !128 + %1125 = icmp sle <2 x i32> %1108, %534, !dbg !116 + %1126 = and <2 x i1> %532, %1125, !dbg !115 + %1127 = and <2 x i1> %536, %1111, !dbg !117 + %1128 = or <2 x i32> %1108, %534, !dbg !122 + %1129 = icmp sgt <2 x i32> %1128, splat (i32 -1), !dbg !122 + %1130 = icmp eq <2 x i32> %538, %1120, !dbg !118 + %1131 = and <2 x i1> %1129, %1130, !dbg !126 + %1132 = or <2 x i1> %1127, %1131, !dbg !127 + %1133 = or <2 x i1> %1126, %1132, !dbg !128 + %1134 = select <2 x i1> %1133, <2 x i1> %625, <2 x i1> zeroinitializer, !dbg !129 + %1135 = select <2 x i1> %1124, <2 x i1> %625, <2 x i1> zeroinitializer, !dbg !129 + %1136 = srem <2 x i32> %614, %528, !dbg !109 + %1137 = icmp sle <2 x i32> %1136, %522, !dbg !116 + %1138 = and <2 x i1> %520, %1137, !dbg !115 + %1139 = icmp slt <2 x i32> %1136, zeroinitializer, !dbg !121 + %1140 = and <2 x i1> %524, %1139, !dbg !117 + %1141 = or <2 x i32> %1136, %522, !dbg !122 + %1142 = icmp sgt <2 x i32> %1141, splat (i32 -1), !dbg !122 + %1143 = and <2 x i32> %1136, splat (i32 15), !dbg !123 + %1144 = icmp ne <2 x i32> %1143, zeroinitializer, !dbg !123 + %1145 = sdiv <2 x i32> %1136, splat (i32 16), !dbg !124 + %1146 = and <2 x i1> %1139, %1144, !dbg !125 + %1147 = sext <2 x i1> %1146 to <2 x i32>, !dbg !125 + %1148 = add nsw <2 x i32> %1145, %1147, !dbg !125 + %1149 = icmp eq <2 x i32> %526, %1148, !dbg !118 + %1150 = and <2 x i1> %1142, %1149, !dbg !126 + %1151 = or <2 x i1> %1140, %1150, !dbg !127 + %1152 = or <2 x i1> %1138, %1151, !dbg !128 + %1153 = icmp sle <2 x i32> %1136, %534, !dbg !116 + %1154 = and <2 x i1> %532, %1153, !dbg !115 + %1155 = and <2 x i1> %536, %1139, !dbg !117 + %1156 = or <2 x i32> %1136, %534, !dbg !122 + %1157 = icmp sgt <2 x i32> %1156, splat (i32 -1), !dbg !122 + %1158 = icmp eq <2 x i32> %538, %1148, !dbg !118 + %1159 = and <2 x i1> %1157, %1158, !dbg !126 + %1160 = or <2 x i1> %1155, %1159, !dbg !127 + %1161 = or <2 x i1> %1154, %1160, !dbg !128 + %1162 = select <2 x i1> %1161, <2 x i1> %626, <2 x i1> zeroinitializer, !dbg !129 + %1163 = select <2 x i1> %1152, <2 x i1> %626, <2 x i1> zeroinitializer, !dbg !129 + %1164 = srem <2 x i32> %613, %528, !dbg !109 + %1165 = icmp sle <2 x i32> %1164, %522, !dbg !116 + %1166 = and <2 x i1> %520, %1165, !dbg !115 + %1167 = icmp slt <2 x i32> %1164, zeroinitializer, !dbg !121 + %1168 = and <2 x i1> %524, %1167, !dbg !117 + %1169 = or <2 x i32> %1164, %522, !dbg !122 + %1170 = icmp sgt <2 x i32> %1169, splat (i32 -1), !dbg !122 + %1171 = and <2 x i32> %1164, splat (i32 15), !dbg !123 + %1172 = icmp ne <2 x i32> %1171, zeroinitializer, !dbg !123 + %1173 = sdiv <2 x i32> %1164, splat (i32 16), !dbg !124 + %1174 = and <2 x i1> %1167, %1172, !dbg !125 + %1175 = sext <2 x i1> %1174 to <2 x i32>, !dbg !125 + %1176 = add nsw <2 x i32> %1173, %1175, !dbg !125 + %1177 = icmp eq <2 x i32> %526, %1176, !dbg !118 + %1178 = and <2 x i1> %1170, %1177, !dbg !126 + %1179 = or <2 x i1> %1168, %1178, !dbg !127 + %1180 = or <2 x i1> %1166, %1179, !dbg !128 + %1181 = icmp sle <2 x i32> %1164, %534, !dbg !116 + %1182 = and <2 x i1> %532, %1181, !dbg !115 + %1183 = and <2 x i1> %536, %1167, !dbg !117 + %1184 = or <2 x i32> %1164, %534, !dbg !122 + %1185 = icmp sgt <2 x i32> %1184, splat (i32 -1), !dbg !122 + %1186 = icmp eq <2 x i32> %538, %1176, !dbg !118 + %1187 = and <2 x i1> %1185, %1186, !dbg !126 + %1188 = or <2 x i1> %1183, %1187, !dbg !127 + %1189 = or <2 x i1> %1182, %1188, !dbg !128 + %1190 = select <2 x i1> %1189, <2 x i1> %627, <2 x i1> zeroinitializer, !dbg !129 + %1191 = select <2 x i1> %1180, <2 x i1> %627, <2 x i1> zeroinitializer, !dbg !129 + %1192 = srem <2 x i32> %612, %528, !dbg !109 + %1193 = icmp sle <2 x i32> %1192, %522, !dbg !116 + %1194 = and <2 x i1> %520, %1193, !dbg !115 + %1195 = icmp slt <2 x i32> %1192, zeroinitializer, !dbg !121 + %1196 = and <2 x i1> %524, %1195, !dbg !117 + %1197 = or <2 x i32> %1192, %522, !dbg !122 + %1198 = icmp sgt <2 x i32> %1197, splat (i32 -1), !dbg !122 + %1199 = and <2 x i32> %1192, splat (i32 15), !dbg !123 + %1200 = icmp ne <2 x i32> %1199, zeroinitializer, !dbg !123 + %1201 = sdiv <2 x i32> %1192, splat (i32 16), !dbg !124 + %1202 = and <2 x i1> %1195, %1200, !dbg !125 + %1203 = sext <2 x i1> %1202 to <2 x i32>, !dbg !125 + %1204 = add nsw <2 x i32> %1201, %1203, !dbg !125 + %1205 = icmp eq <2 x i32> %526, %1204, !dbg !118 + %1206 = and <2 x i1> %1198, %1205, !dbg !126 + %1207 = or <2 x i1> %1196, %1206, !dbg !127 + %1208 = or <2 x i1> %1194, %1207, !dbg !128 + %1209 = icmp sle <2 x i32> %1192, %534, !dbg !116 + %1210 = and <2 x i1> %532, %1209, !dbg !115 + %1211 = and <2 x i1> %536, %1195, !dbg !117 + %1212 = or <2 x i32> %1192, %534, !dbg !122 + %1213 = icmp sgt <2 x i32> %1212, splat (i32 -1), !dbg !122 + %1214 = icmp eq <2 x i32> %538, %1204, !dbg !118 + %1215 = and <2 x i1> %1213, %1214, !dbg !126 + %1216 = or <2 x i1> %1211, %1215, !dbg !127 + %1217 = or <2 x i1> %1210, %1216, !dbg !128 + %1218 = select <2 x i1> %1217, <2 x i1> %628, <2 x i1> zeroinitializer, !dbg !129 + %1219 = select <2 x i1> %1208, <2 x i1> %628, <2 x i1> zeroinitializer, !dbg !129 + %1220 = srem <2 x i32> %611, %528, !dbg !109 + %1221 = icmp sle <2 x i32> %1220, %522, !dbg !116 + %1222 = and <2 x i1> %520, %1221, !dbg !115 + %1223 = icmp slt <2 x i32> %1220, zeroinitializer, !dbg !121 + %1224 = and <2 x i1> %524, %1223, !dbg !117 + %1225 = or <2 x i32> %1220, %522, !dbg !122 + %1226 = icmp sgt <2 x i32> %1225, splat (i32 -1), !dbg !122 + %1227 = and <2 x i32> %1220, splat (i32 15), !dbg !123 + %1228 = icmp ne <2 x i32> %1227, zeroinitializer, !dbg !123 + %1229 = sdiv <2 x i32> %1220, splat (i32 16), !dbg !124 + %1230 = and <2 x i1> %1223, %1228, !dbg !125 + %1231 = sext <2 x i1> %1230 to <2 x i32>, !dbg !125 + %1232 = add nsw <2 x i32> %1229, %1231, !dbg !125 + %1233 = icmp eq <2 x i32> %526, %1232, !dbg !118 + %1234 = and <2 x i1> %1226, %1233, !dbg !126 + %1235 = or <2 x i1> %1224, %1234, !dbg !127 + %1236 = or <2 x i1> %1222, %1235, !dbg !128 + %1237 = icmp sle <2 x i32> %1220, %534, !dbg !116 + %1238 = and <2 x i1> %532, %1237, !dbg !115 + %1239 = and <2 x i1> %536, %1223, !dbg !117 + %1240 = or <2 x i32> %1220, %534, !dbg !122 + %1241 = icmp sgt <2 x i32> %1240, splat (i32 -1), !dbg !122 + %1242 = icmp eq <2 x i32> %538, %1232, !dbg !118 + %1243 = and <2 x i1> %1241, %1242, !dbg !126 + %1244 = or <2 x i1> %1239, %1243, !dbg !127 + %1245 = or <2 x i1> %1238, %1244, !dbg !128 + %1246 = select <2 x i1> %1245, <2 x i1> %629, <2 x i1> zeroinitializer, !dbg !129 + %1247 = select <2 x i1> %1236, <2 x i1> %629, <2 x i1> zeroinitializer, !dbg !129 + %1248 = srem <2 x i32> %610, %528, !dbg !109 + %1249 = icmp sle <2 x i32> %1248, %522, !dbg !116 + %1250 = and <2 x i1> %520, %1249, !dbg !115 + %1251 = icmp slt <2 x i32> %1248, zeroinitializer, !dbg !121 + %1252 = and <2 x i1> %524, %1251, !dbg !117 + %1253 = or <2 x i32> %1248, %522, !dbg !122 + %1254 = icmp sgt <2 x i32> %1253, splat (i32 -1), !dbg !122 + %1255 = and <2 x i32> %1248, splat (i32 15), !dbg !123 + %1256 = icmp ne <2 x i32> %1255, zeroinitializer, !dbg !123 + %1257 = sdiv <2 x i32> %1248, splat (i32 16), !dbg !124 + %1258 = and <2 x i1> %1251, %1256, !dbg !125 + %1259 = sext <2 x i1> %1258 to <2 x i32>, !dbg !125 + %1260 = add nsw <2 x i32> %1257, %1259, !dbg !125 + %1261 = icmp eq <2 x i32> %526, %1260, !dbg !118 + %1262 = and <2 x i1> %1254, %1261, !dbg !126 + %1263 = or <2 x i1> %1252, %1262, !dbg !127 + %1264 = or <2 x i1> %1250, %1263, !dbg !128 + %1265 = icmp sle <2 x i32> %1248, %534, !dbg !116 + %1266 = and <2 x i1> %532, %1265, !dbg !115 + %1267 = and <2 x i1> %536, %1251, !dbg !117 + %1268 = or <2 x i32> %1248, %534, !dbg !122 + %1269 = icmp sgt <2 x i32> %1268, splat (i32 -1), !dbg !122 + %1270 = icmp eq <2 x i32> %538, %1260, !dbg !118 + %1271 = and <2 x i1> %1269, %1270, !dbg !126 + %1272 = or <2 x i1> %1267, %1271, !dbg !127 + %1273 = or <2 x i1> %1266, %1272, !dbg !128 + %1274 = select <2 x i1> %1273, <2 x i1> %630, <2 x i1> zeroinitializer, !dbg !129 + %1275 = select <2 x i1> %1264, <2 x i1> %630, <2 x i1> zeroinitializer, !dbg !129 + %1276 = fmul float %1020, 0x3FF7154760000000, !dbg !130 + %1277 = extractelement <2 x i1> %1078, i64 0, !dbg !129 + %1278 = select i1 %1277, float %1276, float 0xFFF0000000000000, !dbg !129 + %1279 = fmul float %1021, 0x3FF7154760000000, !dbg !130 + %1280 = extractelement <2 x i1> %1078, i64 1, !dbg !129 + %1281 = select i1 %1280, float %1279, float 0xFFF0000000000000, !dbg !129 + %1282 = fmul float %1022, 0x3FF7154760000000, !dbg !130 + %1283 = extractelement <2 x i1> %1079, i64 0, !dbg !129 + %1284 = select i1 %1283, float %1282, float 0xFFF0000000000000, !dbg !129 + %1285 = fmul float %1023, 0x3FF7154760000000, !dbg !130 + %1286 = extractelement <2 x i1> %1079, i64 1, !dbg !129 + %1287 = select i1 %1286, float %1285, float 0xFFF0000000000000, !dbg !129 + %1288 = fmul float %1024, 0x3FF7154760000000, !dbg !130 + %1289 = extractelement <2 x i1> %1106, i64 0, !dbg !129 + %1290 = select i1 %1289, float %1288, float 0xFFF0000000000000, !dbg !129 + %1291 = fmul float %1025, 0x3FF7154760000000, !dbg !130 + %1292 = extractelement <2 x i1> %1106, i64 1, !dbg !129 + %1293 = select i1 %1292, float %1291, float 0xFFF0000000000000, !dbg !129 + %1294 = fmul float %1026, 0x3FF7154760000000, !dbg !130 + %1295 = extractelement <2 x i1> %1107, i64 0, !dbg !129 + %1296 = select i1 %1295, float %1294, float 0xFFF0000000000000, !dbg !129 + %1297 = fmul float %1027, 0x3FF7154760000000, !dbg !130 + %1298 = extractelement <2 x i1> %1107, i64 1, !dbg !129 + %1299 = select i1 %1298, float %1297, float 0xFFF0000000000000, !dbg !129 + %1300 = fmul float %1028, 0x3FF7154760000000, !dbg !130 + %1301 = extractelement <2 x i1> %1134, i64 0, !dbg !129 + %1302 = select i1 %1301, float %1300, float 0xFFF0000000000000, !dbg !129 + %1303 = fmul float %1029, 0x3FF7154760000000, !dbg !130 + %1304 = extractelement <2 x i1> %1134, i64 1, !dbg !129 + %1305 = select i1 %1304, float %1303, float 0xFFF0000000000000, !dbg !129 + %1306 = fmul float %1030, 0x3FF7154760000000, !dbg !130 + %1307 = extractelement <2 x i1> %1135, i64 0, !dbg !129 + %1308 = select i1 %1307, float %1306, float 0xFFF0000000000000, !dbg !129 + %1309 = fmul float %1031, 0x3FF7154760000000, !dbg !130 + %1310 = extractelement <2 x i1> %1135, i64 1, !dbg !129 + %1311 = select i1 %1310, float %1309, float 0xFFF0000000000000, !dbg !129 + %1312 = fmul float %1032, 0x3FF7154760000000, !dbg !130 + %1313 = extractelement <2 x i1> %1162, i64 0, !dbg !129 + %1314 = select i1 %1313, float %1312, float 0xFFF0000000000000, !dbg !129 + %1315 = fmul float %1033, 0x3FF7154760000000, !dbg !130 + %1316 = extractelement <2 x i1> %1162, i64 1, !dbg !129 + %1317 = select i1 %1316, float %1315, float 0xFFF0000000000000, !dbg !129 + %1318 = fmul float %1034, 0x3FF7154760000000, !dbg !130 + %1319 = extractelement <2 x i1> %1163, i64 0, !dbg !129 + %1320 = select i1 %1319, float %1318, float 0xFFF0000000000000, !dbg !129 + %1321 = fmul float %1035, 0x3FF7154760000000, !dbg !130 + %1322 = extractelement <2 x i1> %1163, i64 1, !dbg !129 + %1323 = select i1 %1322, float %1321, float 0xFFF0000000000000, !dbg !129 + %1324 = fmul float %1036, 0x3FF7154760000000, !dbg !130 + %1325 = extractelement <2 x i1> %1190, i64 0, !dbg !129 + %1326 = select i1 %1325, float %1324, float 0xFFF0000000000000, !dbg !129 + %1327 = fmul float %1037, 0x3FF7154760000000, !dbg !130 + %1328 = extractelement <2 x i1> %1190, i64 1, !dbg !129 + %1329 = select i1 %1328, float %1327, float 0xFFF0000000000000, !dbg !129 + %1330 = fmul float %1038, 0x3FF7154760000000, !dbg !130 + %1331 = extractelement <2 x i1> %1191, i64 0, !dbg !129 + %1332 = select i1 %1331, float %1330, float 0xFFF0000000000000, !dbg !129 + %1333 = fmul float %1039, 0x3FF7154760000000, !dbg !130 + %1334 = extractelement <2 x i1> %1191, i64 1, !dbg !129 + %1335 = select i1 %1334, float %1333, float 0xFFF0000000000000, !dbg !129 + %1336 = fmul float %1040, 0x3FF7154760000000, !dbg !130 + %1337 = extractelement <2 x i1> %1218, i64 0, !dbg !129 + %1338 = select i1 %1337, float %1336, float 0xFFF0000000000000, !dbg !129 + %1339 = fmul float %1041, 0x3FF7154760000000, !dbg !130 + %1340 = extractelement <2 x i1> %1218, i64 1, !dbg !129 + %1341 = select i1 %1340, float %1339, float 0xFFF0000000000000, !dbg !129 + %1342 = fmul float %1042, 0x3FF7154760000000, !dbg !130 + %1343 = extractelement <2 x i1> %1219, i64 0, !dbg !129 + %1344 = select i1 %1343, float %1342, float 0xFFF0000000000000, !dbg !129 + %1345 = fmul float %1043, 0x3FF7154760000000, !dbg !130 + %1346 = extractelement <2 x i1> %1219, i64 1, !dbg !129 + %1347 = select i1 %1346, float %1345, float 0xFFF0000000000000, !dbg !129 + %1348 = fmul float %1044, 0x3FF7154760000000, !dbg !130 + %1349 = extractelement <2 x i1> %1246, i64 0, !dbg !129 + %1350 = select i1 %1349, float %1348, float 0xFFF0000000000000, !dbg !129 + %1351 = fmul float %1045, 0x3FF7154760000000, !dbg !130 + %1352 = extractelement <2 x i1> %1246, i64 1, !dbg !129 + %1353 = select i1 %1352, float %1351, float 0xFFF0000000000000, !dbg !129 + %1354 = fmul float %1046, 0x3FF7154760000000, !dbg !130 + %1355 = extractelement <2 x i1> %1247, i64 0, !dbg !129 + %1356 = select i1 %1355, float %1354, float 0xFFF0000000000000, !dbg !129 + %1357 = fmul float %1047, 0x3FF7154760000000, !dbg !130 + %1358 = extractelement <2 x i1> %1247, i64 1, !dbg !129 + %1359 = select i1 %1358, float %1357, float 0xFFF0000000000000, !dbg !129 + %1360 = fmul float %1048, 0x3FF7154760000000, !dbg !130 + %1361 = extractelement <2 x i1> %1274, i64 0, !dbg !129 + %1362 = select i1 %1361, float %1360, float 0xFFF0000000000000, !dbg !129 + %1363 = fmul float %1049, 0x3FF7154760000000, !dbg !130 + %1364 = extractelement <2 x i1> %1274, i64 1, !dbg !129 + %1365 = select i1 %1364, float %1363, float 0xFFF0000000000000, !dbg !129 + %1366 = fmul float %1050, 0x3FF7154760000000, !dbg !130 + %1367 = extractelement <2 x i1> %1275, i64 0, !dbg !129 + %1368 = select i1 %1367, float %1366, float 0xFFF0000000000000, !dbg !129 + %1369 = fmul float %1051, 0x3FF7154760000000, !dbg !130 + %1370 = extractelement <2 x i1> %1275, i64 1, !dbg !129 + %1371 = select i1 %1370, float %1369, float 0xFFF0000000000000, !dbg !129 + %1372 = fsub float %1278, %373, !dbg !131 + %1373 = fsub float %1281, %373, !dbg !131 + %1374 = fsub float %1284, %374, !dbg !131 + %1375 = fsub float %1287, %374, !dbg !131 + %1376 = fsub float %1290, %373, !dbg !131 + %1377 = fsub float %1293, %373, !dbg !131 + %1378 = fsub float %1296, %374, !dbg !131 + %1379 = fsub float %1299, %374, !dbg !131 + %1380 = fsub float %1302, %373, !dbg !131 + %1381 = fsub float %1305, %373, !dbg !131 + %1382 = fsub float %1308, %374, !dbg !131 + %1383 = fsub float %1311, %374, !dbg !131 + %1384 = fsub float %1314, %373, !dbg !131 + %1385 = fsub float %1317, %373, !dbg !131 + %1386 = fsub float %1320, %374, !dbg !131 + %1387 = fsub float %1323, %374, !dbg !131 + %1388 = fsub float %1326, %373, !dbg !131 + %1389 = fsub float %1329, %373, !dbg !131 + %1390 = fsub float %1332, %374, !dbg !131 + %1391 = fsub float %1335, %374, !dbg !131 + %1392 = fsub float %1338, %373, !dbg !131 + %1393 = fsub float %1341, %373, !dbg !131 + %1394 = fsub float %1344, %374, !dbg !131 + %1395 = fsub float %1347, %374, !dbg !131 + %1396 = fsub float %1350, %373, !dbg !131 + %1397 = fsub float %1353, %373, !dbg !131 + %1398 = fsub float %1356, %374, !dbg !131 + %1399 = fsub float %1359, %374, !dbg !131 + %1400 = fsub float %1362, %373, !dbg !131 + %1401 = fsub float %1365, %373, !dbg !131 + %1402 = fsub float %1368, %374, !dbg !131 + %1403 = fsub float %1371, %374, !dbg !131 + %1404 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1514 = icmp eq i32 %1404, 0, !dbg !132 + br i1 %.not.i1514, label %1407, label %1405, !dbg !132 + +1405: ; preds = %541 + %1406 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1372) #3, !dbg !132 + br label %__nv_exp2f.exit1516, !dbg !132 + +1407: ; preds = %541 + %1408 = tail call float @llvm.nvvm.ex2.approx.f(float %1372) #3, !dbg !132 + br label %__nv_exp2f.exit1516, !dbg !132 + +__nv_exp2f.exit1516: ; preds = %1405, %1407 + %.0.i1515 = phi float [ %1406, %1405 ], [ %1408, %1407 ], !dbg !132 + %1409 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1517 = icmp eq i32 %1409, 0, !dbg !132 + br i1 %.not.i1517, label %1412, label %1410, !dbg !132 + +1410: ; preds = %__nv_exp2f.exit1516 + %1411 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1373) #3, !dbg !132 + br label %__nv_exp2f.exit1519, !dbg !132 + +1412: ; preds = %__nv_exp2f.exit1516 + %1413 = tail call float @llvm.nvvm.ex2.approx.f(float %1373) #3, !dbg !132 + br label %__nv_exp2f.exit1519, !dbg !132 + +__nv_exp2f.exit1519: ; preds = %1410, %1412 + %.0.i1518 = phi float [ %1411, %1410 ], [ %1413, %1412 ], !dbg !132 + %1414 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1520 = icmp eq i32 %1414, 0, !dbg !132 + br i1 %.not.i1520, label %1417, label %1415, !dbg !132 + +1415: ; preds = %__nv_exp2f.exit1519 + %1416 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1374) #3, !dbg !132 + br label %__nv_exp2f.exit1522, !dbg !132 + +1417: ; preds = %__nv_exp2f.exit1519 + %1418 = tail call float @llvm.nvvm.ex2.approx.f(float %1374) #3, !dbg !132 + br label %__nv_exp2f.exit1522, !dbg !132 + +__nv_exp2f.exit1522: ; preds = %1415, %1417 + %.0.i1521 = phi float [ %1416, %1415 ], [ %1418, %1417 ], !dbg !132 + %1419 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1523 = icmp eq i32 %1419, 0, !dbg !132 + br i1 %.not.i1523, label %1422, label %1420, !dbg !132 + +1420: ; preds = %__nv_exp2f.exit1522 + %1421 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1375) #3, !dbg !132 + br label %__nv_exp2f.exit1525, !dbg !132 + +1422: ; preds = %__nv_exp2f.exit1522 + %1423 = tail call float @llvm.nvvm.ex2.approx.f(float %1375) #3, !dbg !132 + br label %__nv_exp2f.exit1525, !dbg !132 + +__nv_exp2f.exit1525: ; preds = %1420, %1422 + %.0.i1524 = phi float [ %1421, %1420 ], [ %1423, %1422 ], !dbg !132 + %1424 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1526 = icmp eq i32 %1424, 0, !dbg !132 + br i1 %.not.i1526, label %1427, label %1425, !dbg !132 + +1425: ; preds = %__nv_exp2f.exit1525 + %1426 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1376) #3, !dbg !132 + br label %__nv_exp2f.exit1528, !dbg !132 + +1427: ; preds = %__nv_exp2f.exit1525 + %1428 = tail call float @llvm.nvvm.ex2.approx.f(float %1376) #3, !dbg !132 + br label %__nv_exp2f.exit1528, !dbg !132 + +__nv_exp2f.exit1528: ; preds = %1425, %1427 + %.0.i1527 = phi float [ %1426, %1425 ], [ %1428, %1427 ], !dbg !132 + %1429 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1529 = icmp eq i32 %1429, 0, !dbg !132 + br i1 %.not.i1529, label %1432, label %1430, !dbg !132 + +1430: ; preds = %__nv_exp2f.exit1528 + %1431 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1377) #3, !dbg !132 + br label %__nv_exp2f.exit1531, !dbg !132 + +1432: ; preds = %__nv_exp2f.exit1528 + %1433 = tail call float @llvm.nvvm.ex2.approx.f(float %1377) #3, !dbg !132 + br label %__nv_exp2f.exit1531, !dbg !132 + +__nv_exp2f.exit1531: ; preds = %1430, %1432 + %.0.i1530 = phi float [ %1431, %1430 ], [ %1433, %1432 ], !dbg !132 + %1434 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1532 = icmp eq i32 %1434, 0, !dbg !132 + br i1 %.not.i1532, label %1437, label %1435, !dbg !132 + +1435: ; preds = %__nv_exp2f.exit1531 + %1436 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1378) #3, !dbg !132 + br label %__nv_exp2f.exit1534, !dbg !132 + +1437: ; preds = %__nv_exp2f.exit1531 + %1438 = tail call float @llvm.nvvm.ex2.approx.f(float %1378) #3, !dbg !132 + br label %__nv_exp2f.exit1534, !dbg !132 + +__nv_exp2f.exit1534: ; preds = %1435, %1437 + %.0.i1533 = phi float [ %1436, %1435 ], [ %1438, %1437 ], !dbg !132 + %1439 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1535 = icmp eq i32 %1439, 0, !dbg !132 + br i1 %.not.i1535, label %1442, label %1440, !dbg !132 + +1440: ; preds = %__nv_exp2f.exit1534 + %1441 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1379) #3, !dbg !132 + br label %__nv_exp2f.exit1537, !dbg !132 + +1442: ; preds = %__nv_exp2f.exit1534 + %1443 = tail call float @llvm.nvvm.ex2.approx.f(float %1379) #3, !dbg !132 + br label %__nv_exp2f.exit1537, !dbg !132 + +__nv_exp2f.exit1537: ; preds = %1440, %1442 + %.0.i1536 = phi float [ %1441, %1440 ], [ %1443, %1442 ], !dbg !132 + %1444 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1538 = icmp eq i32 %1444, 0, !dbg !132 + br i1 %.not.i1538, label %1447, label %1445, !dbg !132 + +1445: ; preds = %__nv_exp2f.exit1537 + %1446 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1380) #3, !dbg !132 + br label %__nv_exp2f.exit1540, !dbg !132 + +1447: ; preds = %__nv_exp2f.exit1537 + %1448 = tail call float @llvm.nvvm.ex2.approx.f(float %1380) #3, !dbg !132 + br label %__nv_exp2f.exit1540, !dbg !132 + +__nv_exp2f.exit1540: ; preds = %1445, %1447 + %.0.i1539 = phi float [ %1446, %1445 ], [ %1448, %1447 ], !dbg !132 + %1449 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1541 = icmp eq i32 %1449, 0, !dbg !132 + br i1 %.not.i1541, label %1452, label %1450, !dbg !132 + +1450: ; preds = %__nv_exp2f.exit1540 + %1451 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1381) #3, !dbg !132 + br label %__nv_exp2f.exit1543, !dbg !132 + +1452: ; preds = %__nv_exp2f.exit1540 + %1453 = tail call float @llvm.nvvm.ex2.approx.f(float %1381) #3, !dbg !132 + br label %__nv_exp2f.exit1543, !dbg !132 + +__nv_exp2f.exit1543: ; preds = %1450, %1452 + %.0.i1542 = phi float [ %1451, %1450 ], [ %1453, %1452 ], !dbg !132 + %1454 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1544 = icmp eq i32 %1454, 0, !dbg !132 + br i1 %.not.i1544, label %1457, label %1455, !dbg !132 + +1455: ; preds = %__nv_exp2f.exit1543 + %1456 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1382) #3, !dbg !132 + br label %__nv_exp2f.exit1546, !dbg !132 + +1457: ; preds = %__nv_exp2f.exit1543 + %1458 = tail call float @llvm.nvvm.ex2.approx.f(float %1382) #3, !dbg !132 + br label %__nv_exp2f.exit1546, !dbg !132 + +__nv_exp2f.exit1546: ; preds = %1455, %1457 + %.0.i1545 = phi float [ %1456, %1455 ], [ %1458, %1457 ], !dbg !132 + %1459 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1547 = icmp eq i32 %1459, 0, !dbg !132 + br i1 %.not.i1547, label %1462, label %1460, !dbg !132 + +1460: ; preds = %__nv_exp2f.exit1546 + %1461 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1383) #3, !dbg !132 + br label %__nv_exp2f.exit1549, !dbg !132 + +1462: ; preds = %__nv_exp2f.exit1546 + %1463 = tail call float @llvm.nvvm.ex2.approx.f(float %1383) #3, !dbg !132 + br label %__nv_exp2f.exit1549, !dbg !132 + +__nv_exp2f.exit1549: ; preds = %1460, %1462 + %.0.i1548 = phi float [ %1461, %1460 ], [ %1463, %1462 ], !dbg !132 + %1464 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1550 = icmp eq i32 %1464, 0, !dbg !132 + br i1 %.not.i1550, label %1467, label %1465, !dbg !132 + +1465: ; preds = %__nv_exp2f.exit1549 + %1466 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1384) #3, !dbg !132 + br label %__nv_exp2f.exit1552, !dbg !132 + +1467: ; preds = %__nv_exp2f.exit1549 + %1468 = tail call float @llvm.nvvm.ex2.approx.f(float %1384) #3, !dbg !132 + br label %__nv_exp2f.exit1552, !dbg !132 + +__nv_exp2f.exit1552: ; preds = %1465, %1467 + %.0.i1551 = phi float [ %1466, %1465 ], [ %1468, %1467 ], !dbg !132 + %1469 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1553 = icmp eq i32 %1469, 0, !dbg !132 + br i1 %.not.i1553, label %1472, label %1470, !dbg !132 + +1470: ; preds = %__nv_exp2f.exit1552 + %1471 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1385) #3, !dbg !132 + br label %__nv_exp2f.exit1555, !dbg !132 + +1472: ; preds = %__nv_exp2f.exit1552 + %1473 = tail call float @llvm.nvvm.ex2.approx.f(float %1385) #3, !dbg !132 + br label %__nv_exp2f.exit1555, !dbg !132 + +__nv_exp2f.exit1555: ; preds = %1470, %1472 + %.0.i1554 = phi float [ %1471, %1470 ], [ %1473, %1472 ], !dbg !132 + %1474 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1556 = icmp eq i32 %1474, 0, !dbg !132 + br i1 %.not.i1556, label %1477, label %1475, !dbg !132 + +1475: ; preds = %__nv_exp2f.exit1555 + %1476 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1386) #3, !dbg !132 + br label %__nv_exp2f.exit1558, !dbg !132 + +1477: ; preds = %__nv_exp2f.exit1555 + %1478 = tail call float @llvm.nvvm.ex2.approx.f(float %1386) #3, !dbg !132 + br label %__nv_exp2f.exit1558, !dbg !132 + +__nv_exp2f.exit1558: ; preds = %1475, %1477 + %.0.i1557 = phi float [ %1476, %1475 ], [ %1478, %1477 ], !dbg !132 + %1479 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1559 = icmp eq i32 %1479, 0, !dbg !132 + br i1 %.not.i1559, label %1482, label %1480, !dbg !132 + +1480: ; preds = %__nv_exp2f.exit1558 + %1481 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1387) #3, !dbg !132 + br label %__nv_exp2f.exit1561, !dbg !132 + +1482: ; preds = %__nv_exp2f.exit1558 + %1483 = tail call float @llvm.nvvm.ex2.approx.f(float %1387) #3, !dbg !132 + br label %__nv_exp2f.exit1561, !dbg !132 + +__nv_exp2f.exit1561: ; preds = %1480, %1482 + %.0.i1560 = phi float [ %1481, %1480 ], [ %1483, %1482 ], !dbg !132 + %1484 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1562 = icmp eq i32 %1484, 0, !dbg !132 + br i1 %.not.i1562, label %1487, label %1485, !dbg !132 + +1485: ; preds = %__nv_exp2f.exit1561 + %1486 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1388) #3, !dbg !132 + br label %__nv_exp2f.exit1564, !dbg !132 + +1487: ; preds = %__nv_exp2f.exit1561 + %1488 = tail call float @llvm.nvvm.ex2.approx.f(float %1388) #3, !dbg !132 + br label %__nv_exp2f.exit1564, !dbg !132 + +__nv_exp2f.exit1564: ; preds = %1485, %1487 + %.0.i1563 = phi float [ %1486, %1485 ], [ %1488, %1487 ], !dbg !132 + %1489 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1565 = icmp eq i32 %1489, 0, !dbg !132 + br i1 %.not.i1565, label %1492, label %1490, !dbg !132 + +1490: ; preds = %__nv_exp2f.exit1564 + %1491 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1389) #3, !dbg !132 + br label %__nv_exp2f.exit1567, !dbg !132 + +1492: ; preds = %__nv_exp2f.exit1564 + %1493 = tail call float @llvm.nvvm.ex2.approx.f(float %1389) #3, !dbg !132 + br label %__nv_exp2f.exit1567, !dbg !132 + +__nv_exp2f.exit1567: ; preds = %1490, %1492 + %.0.i1566 = phi float [ %1491, %1490 ], [ %1493, %1492 ], !dbg !132 + %1494 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1568 = icmp eq i32 %1494, 0, !dbg !132 + br i1 %.not.i1568, label %1497, label %1495, !dbg !132 + +1495: ; preds = %__nv_exp2f.exit1567 + %1496 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1390) #3, !dbg !132 + br label %__nv_exp2f.exit1570, !dbg !132 + +1497: ; preds = %__nv_exp2f.exit1567 + %1498 = tail call float @llvm.nvvm.ex2.approx.f(float %1390) #3, !dbg !132 + br label %__nv_exp2f.exit1570, !dbg !132 + +__nv_exp2f.exit1570: ; preds = %1495, %1497 + %.0.i1569 = phi float [ %1496, %1495 ], [ %1498, %1497 ], !dbg !132 + %1499 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1571 = icmp eq i32 %1499, 0, !dbg !132 + br i1 %.not.i1571, label %1502, label %1500, !dbg !132 + +1500: ; preds = %__nv_exp2f.exit1570 + %1501 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1391) #3, !dbg !132 + br label %__nv_exp2f.exit1573, !dbg !132 + +1502: ; preds = %__nv_exp2f.exit1570 + %1503 = tail call float @llvm.nvvm.ex2.approx.f(float %1391) #3, !dbg !132 + br label %__nv_exp2f.exit1573, !dbg !132 + +__nv_exp2f.exit1573: ; preds = %1500, %1502 + %.0.i1572 = phi float [ %1501, %1500 ], [ %1503, %1502 ], !dbg !132 + %1504 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1574 = icmp eq i32 %1504, 0, !dbg !132 + br i1 %.not.i1574, label %1507, label %1505, !dbg !132 + +1505: ; preds = %__nv_exp2f.exit1573 + %1506 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1392) #3, !dbg !132 + br label %__nv_exp2f.exit1576, !dbg !132 + +1507: ; preds = %__nv_exp2f.exit1573 + %1508 = tail call float @llvm.nvvm.ex2.approx.f(float %1392) #3, !dbg !132 + br label %__nv_exp2f.exit1576, !dbg !132 + +__nv_exp2f.exit1576: ; preds = %1505, %1507 + %.0.i1575 = phi float [ %1506, %1505 ], [ %1508, %1507 ], !dbg !132 + %1509 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1577 = icmp eq i32 %1509, 0, !dbg !132 + br i1 %.not.i1577, label %1512, label %1510, !dbg !132 + +1510: ; preds = %__nv_exp2f.exit1576 + %1511 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1393) #3, !dbg !132 + br label %__nv_exp2f.exit1579, !dbg !132 + +1512: ; preds = %__nv_exp2f.exit1576 + %1513 = tail call float @llvm.nvvm.ex2.approx.f(float %1393) #3, !dbg !132 + br label %__nv_exp2f.exit1579, !dbg !132 + +__nv_exp2f.exit1579: ; preds = %1510, %1512 + %.0.i1578 = phi float [ %1511, %1510 ], [ %1513, %1512 ], !dbg !132 + %1514 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1580 = icmp eq i32 %1514, 0, !dbg !132 + br i1 %.not.i1580, label %1517, label %1515, !dbg !132 + +1515: ; preds = %__nv_exp2f.exit1579 + %1516 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1394) #3, !dbg !132 + br label %__nv_exp2f.exit1582, !dbg !132 + +1517: ; preds = %__nv_exp2f.exit1579 + %1518 = tail call float @llvm.nvvm.ex2.approx.f(float %1394) #3, !dbg !132 + br label %__nv_exp2f.exit1582, !dbg !132 + +__nv_exp2f.exit1582: ; preds = %1515, %1517 + %.0.i1581 = phi float [ %1516, %1515 ], [ %1518, %1517 ], !dbg !132 + %1519 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1583 = icmp eq i32 %1519, 0, !dbg !132 + br i1 %.not.i1583, label %1522, label %1520, !dbg !132 + +1520: ; preds = %__nv_exp2f.exit1582 + %1521 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1395) #3, !dbg !132 + br label %__nv_exp2f.exit1585, !dbg !132 + +1522: ; preds = %__nv_exp2f.exit1582 + %1523 = tail call float @llvm.nvvm.ex2.approx.f(float %1395) #3, !dbg !132 + br label %__nv_exp2f.exit1585, !dbg !132 + +__nv_exp2f.exit1585: ; preds = %1520, %1522 + %.0.i1584 = phi float [ %1521, %1520 ], [ %1523, %1522 ], !dbg !132 + %1524 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1586 = icmp eq i32 %1524, 0, !dbg !132 + br i1 %.not.i1586, label %1527, label %1525, !dbg !132 + +1525: ; preds = %__nv_exp2f.exit1585 + %1526 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1396) #3, !dbg !132 + br label %__nv_exp2f.exit1588, !dbg !132 + +1527: ; preds = %__nv_exp2f.exit1585 + %1528 = tail call float @llvm.nvvm.ex2.approx.f(float %1396) #3, !dbg !132 + br label %__nv_exp2f.exit1588, !dbg !132 + +__nv_exp2f.exit1588: ; preds = %1525, %1527 + %.0.i1587 = phi float [ %1526, %1525 ], [ %1528, %1527 ], !dbg !132 + %1529 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1589 = icmp eq i32 %1529, 0, !dbg !132 + br i1 %.not.i1589, label %1532, label %1530, !dbg !132 + +1530: ; preds = %__nv_exp2f.exit1588 + %1531 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1397) #3, !dbg !132 + br label %__nv_exp2f.exit1591, !dbg !132 + +1532: ; preds = %__nv_exp2f.exit1588 + %1533 = tail call float @llvm.nvvm.ex2.approx.f(float %1397) #3, !dbg !132 + br label %__nv_exp2f.exit1591, !dbg !132 + +__nv_exp2f.exit1591: ; preds = %1530, %1532 + %.0.i1590 = phi float [ %1531, %1530 ], [ %1533, %1532 ], !dbg !132 + %1534 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1592 = icmp eq i32 %1534, 0, !dbg !132 + br i1 %.not.i1592, label %1537, label %1535, !dbg !132 + +1535: ; preds = %__nv_exp2f.exit1591 + %1536 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1398) #3, !dbg !132 + br label %__nv_exp2f.exit1594, !dbg !132 + +1537: ; preds = %__nv_exp2f.exit1591 + %1538 = tail call float @llvm.nvvm.ex2.approx.f(float %1398) #3, !dbg !132 + br label %__nv_exp2f.exit1594, !dbg !132 + +__nv_exp2f.exit1594: ; preds = %1535, %1537 + %.0.i1593 = phi float [ %1536, %1535 ], [ %1538, %1537 ], !dbg !132 + %1539 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1595 = icmp eq i32 %1539, 0, !dbg !132 + br i1 %.not.i1595, label %1542, label %1540, !dbg !132 + +1540: ; preds = %__nv_exp2f.exit1594 + %1541 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1399) #3, !dbg !132 + br label %__nv_exp2f.exit1597, !dbg !132 + +1542: ; preds = %__nv_exp2f.exit1594 + %1543 = tail call float @llvm.nvvm.ex2.approx.f(float %1399) #3, !dbg !132 + br label %__nv_exp2f.exit1597, !dbg !132 + +__nv_exp2f.exit1597: ; preds = %1540, %1542 + %.0.i1596 = phi float [ %1541, %1540 ], [ %1543, %1542 ], !dbg !132 + %1544 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1598 = icmp eq i32 %1544, 0, !dbg !132 + br i1 %.not.i1598, label %1547, label %1545, !dbg !132 + +1545: ; preds = %__nv_exp2f.exit1597 + %1546 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1400) #3, !dbg !132 + br label %__nv_exp2f.exit1600, !dbg !132 + +1547: ; preds = %__nv_exp2f.exit1597 + %1548 = tail call float @llvm.nvvm.ex2.approx.f(float %1400) #3, !dbg !132 + br label %__nv_exp2f.exit1600, !dbg !132 + +__nv_exp2f.exit1600: ; preds = %1545, %1547 + %.0.i1599 = phi float [ %1546, %1545 ], [ %1548, %1547 ], !dbg !132 + %1549 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1601 = icmp eq i32 %1549, 0, !dbg !132 + br i1 %.not.i1601, label %1552, label %1550, !dbg !132 + +1550: ; preds = %__nv_exp2f.exit1600 + %1551 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1401) #3, !dbg !132 + br label %__nv_exp2f.exit1603, !dbg !132 + +1552: ; preds = %__nv_exp2f.exit1600 + %1553 = tail call float @llvm.nvvm.ex2.approx.f(float %1401) #3, !dbg !132 + br label %__nv_exp2f.exit1603, !dbg !132 + +__nv_exp2f.exit1603: ; preds = %1550, %1552 + %.0.i1602 = phi float [ %1551, %1550 ], [ %1553, %1552 ], !dbg !132 + %1554 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1604 = icmp eq i32 %1554, 0, !dbg !132 + br i1 %.not.i1604, label %1557, label %1555, !dbg !132 + +1555: ; preds = %__nv_exp2f.exit1603 + %1556 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1402) #3, !dbg !132 + br label %__nv_exp2f.exit1606, !dbg !132 + +1557: ; preds = %__nv_exp2f.exit1603 + %1558 = tail call float @llvm.nvvm.ex2.approx.f(float %1402) #3, !dbg !132 + br label %__nv_exp2f.exit1606, !dbg !132 + +__nv_exp2f.exit1606: ; preds = %1555, %1557 + %.0.i1605 = phi float [ %1556, %1555 ], [ %1558, %1557 ], !dbg !132 + %1559 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !132 + %.not.i1607 = icmp eq i32 %1559, 0, !dbg !132 + br i1 %.not.i1607, label %1562, label %1560, !dbg !132 + +1560: ; preds = %__nv_exp2f.exit1606 + %1561 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1403) #3, !dbg !132 + br label %__nv_exp2f.exit1609, !dbg !132 + +1562: ; preds = %__nv_exp2f.exit1606 + %1563 = tail call float @llvm.nvvm.ex2.approx.f(float %1403) #3, !dbg !132 + br label %__nv_exp2f.exit1609, !dbg !132 + +__nv_exp2f.exit1609: ; preds = %1560, %1562 + %.0.i1608 = phi float [ %1561, %1560 ], [ %1563, %1562 ], !dbg !132 + %1564 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %631, !dbg !104 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !133 + %1565 = add i32 %635, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1566 = lshr exact i32 %1565, 4, !dbg !133 + %1567 = and i32 %1566, 16383, !dbg !133 + %1568 = zext nneg i32 %1567 to i64, !dbg !133 + %1569 = or disjoint i64 %1568, 4611686293372403712, !dbg !133 + %1570 = ptrtoint ptr addrspace(3) %1564 to i32, !dbg !133 + %1571 = lshr exact i32 %1570, 4, !dbg !133 + %1572 = and i32 %1571, 16383, !dbg !133 + %1573 = zext nneg i32 %1572 to i64, !dbg !133 + %1574 = or disjoint i64 %1573, 4611686293338849280, !dbg !133 + %1575 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %1569, i64 %1574) #3, !dbg !133 + %1576 = add i32 %647, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1577 = lshr exact i32 %1576, 4, !dbg !133 + %1578 = and i32 %1577, 16383, !dbg !133 + %1579 = zext nneg i32 %1578 to i64, !dbg !133 + %1580 = or disjoint i64 %1579, 4611686293372403712, !dbg !133 + %1581 = add i32 %1570, 32, !dbg !133 + %1582 = lshr exact i32 %1581, 4, !dbg !133 + %1583 = and i32 %1582, 16383, !dbg !133 + %1584 = zext nneg i32 %1583 to i64, !dbg !133 + %1585 = or disjoint i64 %1584, 4611686293338849280, !dbg !133 + %1586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 0, !dbg !133 + %1587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 1, !dbg !133 + %1588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 2, !dbg !133 + %1589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 3, !dbg !133 + %1590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 4, !dbg !133 + %1591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 5, !dbg !133 + %1592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 6, !dbg !133 + %1593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 7, !dbg !133 + %1594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 8, !dbg !133 + %1595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 9, !dbg !133 + %1596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 10, !dbg !133 + %1597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 11, !dbg !133 + %1598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 12, !dbg !133 + %1599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 13, !dbg !133 + %1600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 14, !dbg !133 + %1601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 15, !dbg !133 + %1602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 16, !dbg !133 + %1603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 17, !dbg !133 + %1604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 18, !dbg !133 + %1605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 19, !dbg !133 + %1606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 20, !dbg !133 + %1607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 21, !dbg !133 + %1608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 22, !dbg !133 + %1609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 23, !dbg !133 + %1610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 24, !dbg !133 + %1611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 25, !dbg !133 + %1612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 26, !dbg !133 + %1613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 27, !dbg !133 + %1614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 28, !dbg !133 + %1615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 29, !dbg !133 + %1616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 30, !dbg !133 + %1617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1575, 31, !dbg !133 + %1618 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1586, float %1587, float %1588, float %1589, float %1590, float %1591, float %1592, float %1593, float %1594, float %1595, float %1596, float %1597, float %1598, float %1599, float %1600, float %1601, float %1602, float %1603, float %1604, float %1605, float %1606, float %1607, float %1608, float %1609, float %1610, float %1611, float %1612, float %1613, float %1614, float %1615, float %1616, float %1617, i64 %1580, i64 %1585, i1 true) #3, !dbg !133 + %1619 = add i32 %691, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1620 = lshr exact i32 %1619, 4, !dbg !133 + %1621 = and i32 %1620, 16383, !dbg !133 + %1622 = zext nneg i32 %1621 to i64, !dbg !133 + %1623 = or disjoint i64 %1622, 4611686293372403712, !dbg !133 + %1624 = add i32 %1570, 64, !dbg !133 + %1625 = lshr exact i32 %1624, 4, !dbg !133 + %1626 = and i32 %1625, 16383, !dbg !133 + %1627 = zext nneg i32 %1626 to i64, !dbg !133 + %1628 = or disjoint i64 %1627, 4611686293338849280, !dbg !133 + %1629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 0, !dbg !133 + %1630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 1, !dbg !133 + %1631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 2, !dbg !133 + %1632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 3, !dbg !133 + %1633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 4, !dbg !133 + %1634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 5, !dbg !133 + %1635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 6, !dbg !133 + %1636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 7, !dbg !133 + %1637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 8, !dbg !133 + %1638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 9, !dbg !133 + %1639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 10, !dbg !133 + %1640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 11, !dbg !133 + %1641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 12, !dbg !133 + %1642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 13, !dbg !133 + %1643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 14, !dbg !133 + %1644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 15, !dbg !133 + %1645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 16, !dbg !133 + %1646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 17, !dbg !133 + %1647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 18, !dbg !133 + %1648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 19, !dbg !133 + %1649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 20, !dbg !133 + %1650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 21, !dbg !133 + %1651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 22, !dbg !133 + %1652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 23, !dbg !133 + %1653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 24, !dbg !133 + %1654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 25, !dbg !133 + %1655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 26, !dbg !133 + %1656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 27, !dbg !133 + %1657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 28, !dbg !133 + %1658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 29, !dbg !133 + %1659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 30, !dbg !133 + %1660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1618, 31, !dbg !133 + %1661 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1629, float %1630, float %1631, float %1632, float %1633, float %1634, float %1635, float %1636, float %1637, float %1638, float %1639, float %1640, float %1641, float %1642, float %1643, float %1644, float %1645, float %1646, float %1647, float %1648, float %1649, float %1650, float %1651, float %1652, float %1653, float %1654, float %1655, float %1656, float %1657, float %1658, float %1659, float %1660, i64 %1623, i64 %1628, i1 true) #3, !dbg !133 + %1662 = add i32 %735, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1663 = lshr exact i32 %1662, 4, !dbg !133 + %1664 = and i32 %1663, 16383, !dbg !133 + %1665 = zext nneg i32 %1664 to i64, !dbg !133 + %1666 = or disjoint i64 %1665, 4611686293372403712, !dbg !133 + %1667 = add i32 %1570, 96, !dbg !133 + %1668 = lshr exact i32 %1667, 4, !dbg !133 + %1669 = and i32 %1668, 16383, !dbg !133 + %1670 = zext nneg i32 %1669 to i64, !dbg !133 + %1671 = or disjoint i64 %1670, 4611686293338849280, !dbg !133 + %1672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 0, !dbg !133 + %1673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 1, !dbg !133 + %1674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 2, !dbg !133 + %1675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 3, !dbg !133 + %1676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 4, !dbg !133 + %1677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 5, !dbg !133 + %1678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 6, !dbg !133 + %1679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 7, !dbg !133 + %1680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 8, !dbg !133 + %1681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 9, !dbg !133 + %1682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 10, !dbg !133 + %1683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 11, !dbg !133 + %1684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 12, !dbg !133 + %1685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 13, !dbg !133 + %1686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 14, !dbg !133 + %1687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 15, !dbg !133 + %1688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 16, !dbg !133 + %1689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 17, !dbg !133 + %1690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 18, !dbg !133 + %1691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 19, !dbg !133 + %1692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 20, !dbg !133 + %1693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 21, !dbg !133 + %1694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 22, !dbg !133 + %1695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 23, !dbg !133 + %1696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 24, !dbg !133 + %1697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 25, !dbg !133 + %1698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 26, !dbg !133 + %1699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 27, !dbg !133 + %1700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 28, !dbg !133 + %1701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 29, !dbg !133 + %1702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 30, !dbg !133 + %1703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1661, 31, !dbg !133 + %1704 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1672, float %1673, float %1674, float %1675, float %1676, float %1677, float %1678, float %1679, float %1680, float %1681, float %1682, float %1683, float %1684, float %1685, float %1686, float %1687, float %1688, float %1689, float %1690, float %1691, float %1692, float %1693, float %1694, float %1695, float %1696, float %1697, float %1698, float %1699, float %1700, float %1701, float %1702, float %1703, i64 %1666, i64 %1671, i1 true) #3, !dbg !133 + %1705 = add i32 %779, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1706 = lshr exact i32 %1705, 4, !dbg !133 + %1707 = and i32 %1706, 16383, !dbg !133 + %1708 = zext nneg i32 %1707 to i64, !dbg !133 + %1709 = or disjoint i64 %1708, 4611686293372403712, !dbg !133 + %1710 = add i32 %1570, 8192, !dbg !133 + %1711 = lshr exact i32 %1710, 4, !dbg !133 + %1712 = and i32 %1711, 16383, !dbg !133 + %1713 = zext nneg i32 %1712 to i64, !dbg !133 + %1714 = or disjoint i64 %1713, 4611686293338849280, !dbg !133 + %1715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 0, !dbg !133 + %1716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 1, !dbg !133 + %1717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 2, !dbg !133 + %1718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 3, !dbg !133 + %1719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 4, !dbg !133 + %1720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 5, !dbg !133 + %1721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 6, !dbg !133 + %1722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 7, !dbg !133 + %1723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 8, !dbg !133 + %1724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 9, !dbg !133 + %1725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 10, !dbg !133 + %1726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 11, !dbg !133 + %1727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 12, !dbg !133 + %1728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 13, !dbg !133 + %1729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 14, !dbg !133 + %1730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 15, !dbg !133 + %1731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 16, !dbg !133 + %1732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 17, !dbg !133 + %1733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 18, !dbg !133 + %1734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 19, !dbg !133 + %1735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 20, !dbg !133 + %1736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 21, !dbg !133 + %1737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 22, !dbg !133 + %1738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 23, !dbg !133 + %1739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 24, !dbg !133 + %1740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 25, !dbg !133 + %1741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 26, !dbg !133 + %1742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 27, !dbg !133 + %1743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 28, !dbg !133 + %1744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 29, !dbg !133 + %1745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 30, !dbg !133 + %1746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1704, 31, !dbg !133 + %1747 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1715, float %1716, float %1717, float %1718, float %1719, float %1720, float %1721, float %1722, float %1723, float %1724, float %1725, float %1726, float %1727, float %1728, float %1729, float %1730, float %1731, float %1732, float %1733, float %1734, float %1735, float %1736, float %1737, float %1738, float %1739, float %1740, float %1741, float %1742, float %1743, float %1744, float %1745, float %1746, i64 %1709, i64 %1714, i1 true) #3, !dbg !133 + %1748 = add i32 %823, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1749 = lshr exact i32 %1748, 4, !dbg !133 + %1750 = and i32 %1749, 16383, !dbg !133 + %1751 = zext nneg i32 %1750 to i64, !dbg !133 + %1752 = or disjoint i64 %1751, 4611686293372403712, !dbg !133 + %1753 = add i32 %1570, 8224, !dbg !133 + %1754 = lshr exact i32 %1753, 4, !dbg !133 + %1755 = and i32 %1754, 16383, !dbg !133 + %1756 = zext nneg i32 %1755 to i64, !dbg !133 + %1757 = or disjoint i64 %1756, 4611686293338849280, !dbg !133 + %1758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 0, !dbg !133 + %1759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 1, !dbg !133 + %1760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 2, !dbg !133 + %1761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 3, !dbg !133 + %1762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 4, !dbg !133 + %1763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 5, !dbg !133 + %1764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 6, !dbg !133 + %1765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 7, !dbg !133 + %1766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 8, !dbg !133 + %1767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 9, !dbg !133 + %1768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 10, !dbg !133 + %1769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 11, !dbg !133 + %1770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 12, !dbg !133 + %1771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 13, !dbg !133 + %1772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 14, !dbg !133 + %1773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 15, !dbg !133 + %1774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 16, !dbg !133 + %1775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 17, !dbg !133 + %1776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 18, !dbg !133 + %1777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 19, !dbg !133 + %1778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 20, !dbg !133 + %1779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 21, !dbg !133 + %1780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 22, !dbg !133 + %1781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 23, !dbg !133 + %1782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 24, !dbg !133 + %1783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 25, !dbg !133 + %1784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 26, !dbg !133 + %1785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 27, !dbg !133 + %1786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 28, !dbg !133 + %1787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 29, !dbg !133 + %1788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 30, !dbg !133 + %1789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1747, 31, !dbg !133 + %1790 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1758, float %1759, float %1760, float %1761, float %1762, float %1763, float %1764, float %1765, float %1766, float %1767, float %1768, float %1769, float %1770, float %1771, float %1772, float %1773, float %1774, float %1775, float %1776, float %1777, float %1778, float %1779, float %1780, float %1781, float %1782, float %1783, float %1784, float %1785, float %1786, float %1787, float %1788, float %1789, i64 %1752, i64 %1757, i1 true) #3, !dbg !133 + %1791 = add i32 %867, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1792 = lshr exact i32 %1791, 4, !dbg !133 + %1793 = and i32 %1792, 16383, !dbg !133 + %1794 = zext nneg i32 %1793 to i64, !dbg !133 + %1795 = or disjoint i64 %1794, 4611686293372403712, !dbg !133 + %1796 = add i32 %1570, 8256, !dbg !133 + %1797 = lshr exact i32 %1796, 4, !dbg !133 + %1798 = and i32 %1797, 16383, !dbg !133 + %1799 = zext nneg i32 %1798 to i64, !dbg !133 + %1800 = or disjoint i64 %1799, 4611686293338849280, !dbg !133 + %1801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 0, !dbg !133 + %1802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 1, !dbg !133 + %1803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 2, !dbg !133 + %1804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 3, !dbg !133 + %1805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 4, !dbg !133 + %1806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 5, !dbg !133 + %1807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 6, !dbg !133 + %1808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 7, !dbg !133 + %1809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 8, !dbg !133 + %1810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 9, !dbg !133 + %1811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 10, !dbg !133 + %1812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 11, !dbg !133 + %1813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 12, !dbg !133 + %1814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 13, !dbg !133 + %1815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 14, !dbg !133 + %1816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 15, !dbg !133 + %1817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 16, !dbg !133 + %1818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 17, !dbg !133 + %1819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 18, !dbg !133 + %1820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 19, !dbg !133 + %1821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 20, !dbg !133 + %1822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 21, !dbg !133 + %1823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 22, !dbg !133 + %1824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 23, !dbg !133 + %1825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 24, !dbg !133 + %1826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 25, !dbg !133 + %1827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 26, !dbg !133 + %1828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 27, !dbg !133 + %1829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 28, !dbg !133 + %1830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 29, !dbg !133 + %1831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 30, !dbg !133 + %1832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1790, 31, !dbg !133 + %1833 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1801, float %1802, float %1803, float %1804, float %1805, float %1806, float %1807, float %1808, float %1809, float %1810, float %1811, float %1812, float %1813, float %1814, float %1815, float %1816, float %1817, float %1818, float %1819, float %1820, float %1821, float %1822, float %1823, float %1824, float %1825, float %1826, float %1827, float %1828, float %1829, float %1830, float %1831, float %1832, i64 %1795, i64 %1800, i1 true) #3, !dbg !133 + %1834 = add i32 %911, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !133 + %1835 = lshr exact i32 %1834, 4, !dbg !133 + %1836 = and i32 %1835, 16383, !dbg !133 + %1837 = zext nneg i32 %1836 to i64, !dbg !133 + %1838 = or disjoint i64 %1837, 4611686293372403712, !dbg !133 + %1839 = add i32 %1570, 8288, !dbg !133 + %1840 = lshr exact i32 %1839, 4, !dbg !133 + %1841 = and i32 %1840, 16383, !dbg !133 + %1842 = zext nneg i32 %1841 to i64, !dbg !133 + %1843 = or disjoint i64 %1842, 4611686293338849280, !dbg !133 + %1844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 0, !dbg !133 + %1845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 1, !dbg !133 + %1846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 2, !dbg !133 + %1847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 3, !dbg !133 + %1848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 4, !dbg !133 + %1849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 5, !dbg !133 + %1850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 6, !dbg !133 + %1851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 7, !dbg !133 + %1852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 8, !dbg !133 + %1853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 9, !dbg !133 + %1854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 10, !dbg !133 + %1855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 11, !dbg !133 + %1856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 12, !dbg !133 + %1857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 13, !dbg !133 + %1858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 14, !dbg !133 + %1859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 15, !dbg !133 + %1860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 16, !dbg !133 + %1861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 17, !dbg !133 + %1862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 18, !dbg !133 + %1863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 19, !dbg !133 + %1864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 20, !dbg !133 + %1865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 21, !dbg !133 + %1866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 22, !dbg !133 + %1867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 23, !dbg !133 + %1868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 24, !dbg !133 + %1869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 25, !dbg !133 + %1870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 26, !dbg !133 + %1871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 27, !dbg !133 + %1872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 28, !dbg !133 + %1873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 29, !dbg !133 + %1874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 30, !dbg !133 + %1875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1833, 31, !dbg !133 + %1876 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1844, float %1845, float %1846, float %1847, float %1848, float %1849, float %1850, float %1851, float %1852, float %1853, float %1854, float %1855, float %1856, float %1857, float %1858, float %1859, float %1860, float %1861, float %1862, float %1863, float %1864, float %1865, float %1866, float %1867, float %1868, float %1869, float %1870, float %1871, float %1872, float %1873, float %1874, float %1875, i64 %1838, i64 %1843, i1 true) #3, !dbg !133 + %1877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 0, !dbg !133 + %1878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 1, !dbg !133 + %1879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 2, !dbg !133 + %1880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 3, !dbg !133 + %1881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 4, !dbg !133 + %1882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 5, !dbg !133 + %1883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 6, !dbg !133 + %1884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 7, !dbg !133 + %1885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 8, !dbg !133 + %1886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 9, !dbg !133 + %1887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 10, !dbg !133 + %1888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 11, !dbg !133 + %1889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 12, !dbg !133 + %1890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 13, !dbg !133 + %1891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 14, !dbg !133 + %1892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 15, !dbg !133 + %1893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 16, !dbg !133 + %1894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 17, !dbg !133 + %1895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 18, !dbg !133 + %1896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 19, !dbg !133 + %1897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 20, !dbg !133 + %1898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 21, !dbg !133 + %1899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 22, !dbg !133 + %1900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 23, !dbg !133 + %1901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 24, !dbg !133 + %1902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 25, !dbg !133 + %1903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 26, !dbg !133 + %1904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 27, !dbg !133 + %1905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 28, !dbg !133 + %1906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 29, !dbg !133 + %1907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 30, !dbg !133 + %1908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1876, 31, !dbg !133 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !133 + %1909 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %1877, float %1878, float %1879, float %1880, float %1881, float %1882, float %1883, float %1884, float %1885, float %1886, float %1887, float %1888, float %1889, float %1890, float %1891, float %1892, float %1893, float %1894, float %1895, float %1896, float %1897, float %1898, float %1899, float %1900, float %1901, float %1902, float %1903, float %1904, float %1905, float %1906, float %1907, float %1908, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %1564, i32 0, i32 0) #3, !dbg !133 + %1910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 0, !dbg !133 + %1911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 1, !dbg !133 + %1912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 2, !dbg !133 + %1913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 3, !dbg !133 + %1914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 4, !dbg !133 + %1915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 5, !dbg !133 + %1916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 6, !dbg !133 + %1917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 7, !dbg !133 + %1918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 8, !dbg !133 + %1919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 9, !dbg !133 + %1920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 10, !dbg !133 + %1921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 11, !dbg !133 + %1922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 12, !dbg !133 + %1923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 13, !dbg !133 + %1924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 14, !dbg !133 + %1925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 15, !dbg !133 + %1926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 16, !dbg !133 + %1927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 17, !dbg !133 + %1928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 18, !dbg !133 + %1929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 19, !dbg !133 + %1930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 20, !dbg !133 + %1931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 21, !dbg !133 + %1932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 22, !dbg !133 + %1933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 23, !dbg !133 + %1934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 24, !dbg !133 + %1935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 25, !dbg !133 + %1936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 26, !dbg !133 + %1937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 27, !dbg !133 + %1938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 28, !dbg !133 + %1939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 29, !dbg !133 + %1940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 30, !dbg !133 + %1941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1909, 31, !dbg !133 + %1942 = insertelement <2 x float> poison, float %1910, i64 0, !dbg !119 + %1943 = insertelement <2 x float> %1942, float %1911, i64 1, !dbg !119 + %1944 = fsub <2 x float> %1943, %540, !dbg !119 + %1945 = insertelement <2 x float> poison, float %.0.i1515, i64 0, !dbg !134 + %1946 = insertelement <2 x float> %1945, float %.0.i1518, i64 1, !dbg !134 + %1947 = fmul <2 x float> %1946, %1944, !dbg !134 + %1948 = fptrunc <2 x float> %1947 to <2 x bfloat>, !dbg !135 + %1949 = select <2 x i1> %1078, <2 x bfloat> %1948, <2 x bfloat> zeroinitializer, !dbg !136 + %1950 = insertelement <2 x float> poison, float %1912, i64 0, !dbg !119 + %1951 = insertelement <2 x float> %1950, float %1913, i64 1, !dbg !119 + %1952 = fsub <2 x float> %1951, %530, !dbg !119 + %1953 = insertelement <2 x float> poison, float %.0.i1521, i64 0, !dbg !134 + %1954 = insertelement <2 x float> %1953, float %.0.i1524, i64 1, !dbg !134 + %1955 = fmul <2 x float> %1954, %1952, !dbg !134 + %1956 = fptrunc <2 x float> %1955 to <2 x bfloat>, !dbg !135 + %1957 = select <2 x i1> %1079, <2 x bfloat> %1956, <2 x bfloat> zeroinitializer, !dbg !136 + %1958 = insertelement <2 x float> poison, float %1914, i64 0, !dbg !119 + %1959 = insertelement <2 x float> %1958, float %1915, i64 1, !dbg !119 + %1960 = fsub <2 x float> %1959, %540, !dbg !119 + %1961 = insertelement <2 x float> poison, float %.0.i1527, i64 0, !dbg !134 + %1962 = insertelement <2 x float> %1961, float %.0.i1530, i64 1, !dbg !134 + %1963 = fmul <2 x float> %1962, %1960, !dbg !134 + %1964 = fptrunc <2 x float> %1963 to <2 x bfloat>, !dbg !135 + %1965 = select <2 x i1> %1106, <2 x bfloat> %1964, <2 x bfloat> zeroinitializer, !dbg !136 + %1966 = insertelement <2 x float> poison, float %1916, i64 0, !dbg !119 + %1967 = insertelement <2 x float> %1966, float %1917, i64 1, !dbg !119 + %1968 = fsub <2 x float> %1967, %530, !dbg !119 + %1969 = insertelement <2 x float> poison, float %.0.i1533, i64 0, !dbg !134 + %1970 = insertelement <2 x float> %1969, float %.0.i1536, i64 1, !dbg !134 + %1971 = fmul <2 x float> %1970, %1968, !dbg !134 + %1972 = fptrunc <2 x float> %1971 to <2 x bfloat>, !dbg !135 + %1973 = select <2 x i1> %1107, <2 x bfloat> %1972, <2 x bfloat> zeroinitializer, !dbg !136 + %1974 = insertelement <2 x float> poison, float %1918, i64 0, !dbg !119 + %1975 = insertelement <2 x float> %1974, float %1919, i64 1, !dbg !119 + %1976 = fsub <2 x float> %1975, %540, !dbg !119 + %1977 = insertelement <2 x float> poison, float %.0.i1539, i64 0, !dbg !134 + %1978 = insertelement <2 x float> %1977, float %.0.i1542, i64 1, !dbg !134 + %1979 = fmul <2 x float> %1978, %1976, !dbg !134 + %1980 = fptrunc <2 x float> %1979 to <2 x bfloat>, !dbg !135 + %1981 = select <2 x i1> %1134, <2 x bfloat> %1980, <2 x bfloat> zeroinitializer, !dbg !136 + %1982 = insertelement <2 x float> poison, float %1920, i64 0, !dbg !119 + %1983 = insertelement <2 x float> %1982, float %1921, i64 1, !dbg !119 + %1984 = fsub <2 x float> %1983, %530, !dbg !119 + %1985 = insertelement <2 x float> poison, float %.0.i1545, i64 0, !dbg !134 + %1986 = insertelement <2 x float> %1985, float %.0.i1548, i64 1, !dbg !134 + %1987 = fmul <2 x float> %1986, %1984, !dbg !134 + %1988 = fptrunc <2 x float> %1987 to <2 x bfloat>, !dbg !135 + %1989 = select <2 x i1> %1135, <2 x bfloat> %1988, <2 x bfloat> zeroinitializer, !dbg !136 + %1990 = insertelement <2 x float> poison, float %1922, i64 0, !dbg !119 + %1991 = insertelement <2 x float> %1990, float %1923, i64 1, !dbg !119 + %1992 = fsub <2 x float> %1991, %540, !dbg !119 + %1993 = insertelement <2 x float> poison, float %.0.i1551, i64 0, !dbg !134 + %1994 = insertelement <2 x float> %1993, float %.0.i1554, i64 1, !dbg !134 + %1995 = fmul <2 x float> %1994, %1992, !dbg !134 + %1996 = fptrunc <2 x float> %1995 to <2 x bfloat>, !dbg !135 + %1997 = select <2 x i1> %1162, <2 x bfloat> %1996, <2 x bfloat> zeroinitializer, !dbg !136 + %1998 = insertelement <2 x float> poison, float %1924, i64 0, !dbg !119 + %1999 = insertelement <2 x float> %1998, float %1925, i64 1, !dbg !119 + %2000 = fsub <2 x float> %1999, %530, !dbg !119 + %2001 = insertelement <2 x float> poison, float %.0.i1557, i64 0, !dbg !134 + %2002 = insertelement <2 x float> %2001, float %.0.i1560, i64 1, !dbg !134 + %2003 = fmul <2 x float> %2002, %2000, !dbg !134 + %2004 = fptrunc <2 x float> %2003 to <2 x bfloat>, !dbg !135 + %2005 = select <2 x i1> %1163, <2 x bfloat> %2004, <2 x bfloat> zeroinitializer, !dbg !136 + %2006 = insertelement <2 x float> poison, float %1926, i64 0, !dbg !119 + %2007 = insertelement <2 x float> %2006, float %1927, i64 1, !dbg !119 + %2008 = fsub <2 x float> %2007, %540, !dbg !119 + %2009 = insertelement <2 x float> poison, float %.0.i1563, i64 0, !dbg !134 + %2010 = insertelement <2 x float> %2009, float %.0.i1566, i64 1, !dbg !134 + %2011 = fmul <2 x float> %2010, %2008, !dbg !134 + %2012 = fptrunc <2 x float> %2011 to <2 x bfloat>, !dbg !135 + %2013 = select <2 x i1> %1190, <2 x bfloat> %2012, <2 x bfloat> zeroinitializer, !dbg !136 + %2014 = insertelement <2 x float> poison, float %1928, i64 0, !dbg !119 + %2015 = insertelement <2 x float> %2014, float %1929, i64 1, !dbg !119 + %2016 = fsub <2 x float> %2015, %530, !dbg !119 + %2017 = insertelement <2 x float> poison, float %.0.i1569, i64 0, !dbg !134 + %2018 = insertelement <2 x float> %2017, float %.0.i1572, i64 1, !dbg !134 + %2019 = fmul <2 x float> %2018, %2016, !dbg !134 + %2020 = fptrunc <2 x float> %2019 to <2 x bfloat>, !dbg !135 + %2021 = select <2 x i1> %1191, <2 x bfloat> %2020, <2 x bfloat> zeroinitializer, !dbg !136 + %2022 = insertelement <2 x float> poison, float %1930, i64 0, !dbg !119 + %2023 = insertelement <2 x float> %2022, float %1931, i64 1, !dbg !119 + %2024 = fsub <2 x float> %2023, %540, !dbg !119 + %2025 = insertelement <2 x float> poison, float %.0.i1575, i64 0, !dbg !134 + %2026 = insertelement <2 x float> %2025, float %.0.i1578, i64 1, !dbg !134 + %2027 = fmul <2 x float> %2026, %2024, !dbg !134 + %2028 = fptrunc <2 x float> %2027 to <2 x bfloat>, !dbg !135 + %2029 = select <2 x i1> %1218, <2 x bfloat> %2028, <2 x bfloat> zeroinitializer, !dbg !136 + %2030 = insertelement <2 x float> poison, float %1932, i64 0, !dbg !119 + %2031 = insertelement <2 x float> %2030, float %1933, i64 1, !dbg !119 + %2032 = fsub <2 x float> %2031, %530, !dbg !119 + %2033 = insertelement <2 x float> poison, float %.0.i1581, i64 0, !dbg !134 + %2034 = insertelement <2 x float> %2033, float %.0.i1584, i64 1, !dbg !134 + %2035 = fmul <2 x float> %2034, %2032, !dbg !134 + %2036 = fptrunc <2 x float> %2035 to <2 x bfloat>, !dbg !135 + %2037 = select <2 x i1> %1219, <2 x bfloat> %2036, <2 x bfloat> zeroinitializer, !dbg !136 + %2038 = insertelement <2 x float> poison, float %1934, i64 0, !dbg !119 + %2039 = insertelement <2 x float> %2038, float %1935, i64 1, !dbg !119 + %2040 = fsub <2 x float> %2039, %540, !dbg !119 + %2041 = insertelement <2 x float> poison, float %.0.i1587, i64 0, !dbg !134 + %2042 = insertelement <2 x float> %2041, float %.0.i1590, i64 1, !dbg !134 + %2043 = fmul <2 x float> %2042, %2040, !dbg !134 + %2044 = fptrunc <2 x float> %2043 to <2 x bfloat>, !dbg !135 + %2045 = select <2 x i1> %1246, <2 x bfloat> %2044, <2 x bfloat> zeroinitializer, !dbg !136 + %2046 = insertelement <2 x float> poison, float %1936, i64 0, !dbg !119 + %2047 = insertelement <2 x float> %2046, float %1937, i64 1, !dbg !119 + %2048 = fsub <2 x float> %2047, %530, !dbg !119 + %2049 = insertelement <2 x float> poison, float %.0.i1593, i64 0, !dbg !134 + %2050 = insertelement <2 x float> %2049, float %.0.i1596, i64 1, !dbg !134 + %2051 = fmul <2 x float> %2050, %2048, !dbg !134 + %2052 = fptrunc <2 x float> %2051 to <2 x bfloat>, !dbg !135 + %2053 = select <2 x i1> %1247, <2 x bfloat> %2052, <2 x bfloat> zeroinitializer, !dbg !136 + %2054 = insertelement <2 x float> poison, float %1938, i64 0, !dbg !119 + %2055 = insertelement <2 x float> %2054, float %1939, i64 1, !dbg !119 + %2056 = fsub <2 x float> %2055, %540, !dbg !119 + %2057 = insertelement <2 x float> poison, float %.0.i1599, i64 0, !dbg !134 + %2058 = insertelement <2 x float> %2057, float %.0.i1602, i64 1, !dbg !134 + %2059 = fmul <2 x float> %2058, %2056, !dbg !134 + %2060 = fptrunc <2 x float> %2059 to <2 x bfloat>, !dbg !135 + %2061 = select <2 x i1> %1274, <2 x bfloat> %2060, <2 x bfloat> zeroinitializer, !dbg !136 + %2062 = insertelement <2 x float> poison, float %1940, i64 0, !dbg !119 + %2063 = insertelement <2 x float> %2062, float %1941, i64 1, !dbg !119 + %2064 = fsub <2 x float> %2063, %530, !dbg !119 + %2065 = insertelement <2 x float> poison, float %.0.i1605, i64 0, !dbg !134 + %2066 = insertelement <2 x float> %2065, float %.0.i1608, i64 1, !dbg !134 + %2067 = fmul <2 x float> %2066, %2064, !dbg !134 + %2068 = fptrunc <2 x float> %2067 to <2 x bfloat>, !dbg !135 + %2069 = select <2 x i1> %1275, <2 x bfloat> %2068, <2 x bfloat> zeroinitializer, !dbg !136 + %2070 = bitcast <2 x bfloat> %1949 to i32, !dbg !137 + %2071 = bitcast <2 x bfloat> %1957 to i32, !dbg !137 + %2072 = bitcast <2 x bfloat> %1965 to i32, !dbg !137 + %2073 = bitcast <2 x bfloat> %1973 to i32, !dbg !137 + %2074 = bitcast <2 x bfloat> %1981 to i32, !dbg !137 + %2075 = bitcast <2 x bfloat> %1989 to i32, !dbg !137 + %2076 = bitcast <2 x bfloat> %1997 to i32, !dbg !137 + %2077 = bitcast <2 x bfloat> %2005 to i32, !dbg !137 + %2078 = bitcast <2 x bfloat> %2013 to i32, !dbg !137 + %2079 = bitcast <2 x bfloat> %2021 to i32, !dbg !137 + %2080 = bitcast <2 x bfloat> %2029 to i32, !dbg !137 + %2081 = bitcast <2 x bfloat> %2037 to i32, !dbg !137 + %2082 = bitcast <2 x bfloat> %2045 to i32, !dbg !137 + %2083 = bitcast <2 x bfloat> %2053 to i32, !dbg !137 + %2084 = bitcast <2 x bfloat> %2061 to i32, !dbg !137 + %2085 = bitcast <2 x bfloat> %2069 to i32, !dbg !137 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !137 + %2086 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %545, float %546, float %547, float %548, float %549, float %550, float %551, float %552, float %553, float %554, float %555, float %556, float %557, float %558, float %559, float %560, float %561, float %562, float %563, float %564, float %565, float %566, float %567, float %568, float %569, float %570, float %571, float %572, float %573, float %574, float %575, float %576, float %577, float %578, float %579, float %580, float %581, float %582, float %583, float %584, float %585, float %586, float %587, float %588, float %589, float %590, float %591, float %592, float %593, float %594, float %595, float %596, float %597, float %598, float %599, float %600, float %601, float %602, float %603, float %604, float %605, float %606, float %607, float %608, i32 %2070, i32 %2071, i32 %2072, i32 %2073, i64 %645, i1 true) #3, !dbg !137 + %2087 = add i32 %641, 2048, !dbg !137 + %2088 = lshr exact i32 %2087, 4, !dbg !137 + %2089 = and i32 %2088, 16383, !dbg !137 + %2090 = zext nneg i32 %2089 to i64, !dbg !137 + %2091 = or disjoint i64 %2090, 4611686293338849280, !dbg !137 + %2092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 0, !dbg !137 + %2093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 1, !dbg !137 + %2094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 2, !dbg !137 + %2095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 3, !dbg !137 + %2096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 4, !dbg !137 + %2097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 5, !dbg !137 + %2098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 6, !dbg !137 + %2099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 7, !dbg !137 + %2100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 8, !dbg !137 + %2101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 9, !dbg !137 + %2102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 10, !dbg !137 + %2103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 11, !dbg !137 + %2104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 12, !dbg !137 + %2105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 13, !dbg !137 + %2106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 14, !dbg !137 + %2107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 15, !dbg !137 + %2108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 16, !dbg !137 + %2109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 17, !dbg !137 + %2110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 18, !dbg !137 + %2111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 19, !dbg !137 + %2112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 20, !dbg !137 + %2113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 21, !dbg !137 + %2114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 22, !dbg !137 + %2115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 23, !dbg !137 + %2116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 24, !dbg !137 + %2117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 25, !dbg !137 + %2118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 26, !dbg !137 + %2119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 27, !dbg !137 + %2120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 28, !dbg !137 + %2121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 29, !dbg !137 + %2122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 30, !dbg !137 + %2123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 31, !dbg !137 + %2124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 32, !dbg !137 + %2125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 33, !dbg !137 + %2126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 34, !dbg !137 + %2127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 35, !dbg !137 + %2128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 36, !dbg !137 + %2129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 37, !dbg !137 + %2130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 38, !dbg !137 + %2131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 39, !dbg !137 + %2132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 40, !dbg !137 + %2133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 41, !dbg !137 + %2134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 42, !dbg !137 + %2135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 43, !dbg !137 + %2136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 44, !dbg !137 + %2137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 45, !dbg !137 + %2138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 46, !dbg !137 + %2139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 47, !dbg !137 + %2140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 48, !dbg !137 + %2141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 49, !dbg !137 + %2142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 50, !dbg !137 + %2143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 51, !dbg !137 + %2144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 52, !dbg !137 + %2145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 53, !dbg !137 + %2146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 54, !dbg !137 + %2147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 55, !dbg !137 + %2148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 56, !dbg !137 + %2149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 57, !dbg !137 + %2150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 58, !dbg !137 + %2151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 59, !dbg !137 + %2152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 60, !dbg !137 + %2153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 61, !dbg !137 + %2154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 62, !dbg !137 + %2155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2086, 63, !dbg !137 + %2156 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2092, float %2093, float %2094, float %2095, float %2096, float %2097, float %2098, float %2099, float %2100, float %2101, float %2102, float %2103, float %2104, float %2105, float %2106, float %2107, float %2108, float %2109, float %2110, float %2111, float %2112, float %2113, float %2114, float %2115, float %2116, float %2117, float %2118, float %2119, float %2120, float %2121, float %2122, float %2123, float %2124, float %2125, float %2126, float %2127, float %2128, float %2129, float %2130, float %2131, float %2132, float %2133, float %2134, float %2135, float %2136, float %2137, float %2138, float %2139, float %2140, float %2141, float %2142, float %2143, float %2144, float %2145, float %2146, float %2147, float %2148, float %2149, float %2150, float %2151, float %2152, float %2153, float %2154, float %2155, i32 %2074, i32 %2075, i32 %2076, i32 %2077, i64 %2091, i1 true) #3, !dbg !137 + %2157 = add i32 %641, 4096, !dbg !137 + %2158 = lshr exact i32 %2157, 4, !dbg !137 + %2159 = and i32 %2158, 16383, !dbg !137 + %2160 = zext nneg i32 %2159 to i64, !dbg !137 + %2161 = or disjoint i64 %2160, 4611686293338849280, !dbg !137 + %2162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 0, !dbg !137 + %2163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 1, !dbg !137 + %2164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 2, !dbg !137 + %2165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 3, !dbg !137 + %2166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 4, !dbg !137 + %2167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 5, !dbg !137 + %2168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 6, !dbg !137 + %2169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 7, !dbg !137 + %2170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 8, !dbg !137 + %2171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 9, !dbg !137 + %2172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 10, !dbg !137 + %2173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 11, !dbg !137 + %2174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 12, !dbg !137 + %2175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 13, !dbg !137 + %2176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 14, !dbg !137 + %2177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 15, !dbg !137 + %2178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 16, !dbg !137 + %2179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 17, !dbg !137 + %2180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 18, !dbg !137 + %2181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 19, !dbg !137 + %2182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 20, !dbg !137 + %2183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 21, !dbg !137 + %2184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 22, !dbg !137 + %2185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 23, !dbg !137 + %2186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 24, !dbg !137 + %2187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 25, !dbg !137 + %2188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 26, !dbg !137 + %2189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 27, !dbg !137 + %2190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 28, !dbg !137 + %2191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 29, !dbg !137 + %2192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 30, !dbg !137 + %2193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 31, !dbg !137 + %2194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 32, !dbg !137 + %2195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 33, !dbg !137 + %2196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 34, !dbg !137 + %2197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 35, !dbg !137 + %2198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 36, !dbg !137 + %2199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 37, !dbg !137 + %2200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 38, !dbg !137 + %2201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 39, !dbg !137 + %2202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 40, !dbg !137 + %2203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 41, !dbg !137 + %2204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 42, !dbg !137 + %2205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 43, !dbg !137 + %2206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 44, !dbg !137 + %2207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 45, !dbg !137 + %2208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 46, !dbg !137 + %2209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 47, !dbg !137 + %2210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 48, !dbg !137 + %2211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 49, !dbg !137 + %2212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 50, !dbg !137 + %2213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 51, !dbg !137 + %2214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 52, !dbg !137 + %2215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 53, !dbg !137 + %2216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 54, !dbg !137 + %2217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 55, !dbg !137 + %2218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 56, !dbg !137 + %2219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 57, !dbg !137 + %2220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 58, !dbg !137 + %2221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 59, !dbg !137 + %2222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 60, !dbg !137 + %2223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 61, !dbg !137 + %2224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 62, !dbg !137 + %2225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2156, 63, !dbg !137 + %2226 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2162, float %2163, float %2164, float %2165, float %2166, float %2167, float %2168, float %2169, float %2170, float %2171, float %2172, float %2173, float %2174, float %2175, float %2176, float %2177, float %2178, float %2179, float %2180, float %2181, float %2182, float %2183, float %2184, float %2185, float %2186, float %2187, float %2188, float %2189, float %2190, float %2191, float %2192, float %2193, float %2194, float %2195, float %2196, float %2197, float %2198, float %2199, float %2200, float %2201, float %2202, float %2203, float %2204, float %2205, float %2206, float %2207, float %2208, float %2209, float %2210, float %2211, float %2212, float %2213, float %2214, float %2215, float %2216, float %2217, float %2218, float %2219, float %2220, float %2221, float %2222, float %2223, float %2224, float %2225, i32 %2078, i32 %2079, i32 %2080, i32 %2081, i64 %2161, i1 true) #3, !dbg !137 + %2227 = add i32 %641, 6144, !dbg !137 + %2228 = lshr exact i32 %2227, 4, !dbg !137 + %2229 = and i32 %2228, 16383, !dbg !137 + %2230 = zext nneg i32 %2229 to i64, !dbg !137 + %2231 = or disjoint i64 %2230, 4611686293338849280, !dbg !137 + %2232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 0, !dbg !137 + %2233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 1, !dbg !137 + %2234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 2, !dbg !137 + %2235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 3, !dbg !137 + %2236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 4, !dbg !137 + %2237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 5, !dbg !137 + %2238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 6, !dbg !137 + %2239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 7, !dbg !137 + %2240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 8, !dbg !137 + %2241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 9, !dbg !137 + %2242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 10, !dbg !137 + %2243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 11, !dbg !137 + %2244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 12, !dbg !137 + %2245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 13, !dbg !137 + %2246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 14, !dbg !137 + %2247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 15, !dbg !137 + %2248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 16, !dbg !137 + %2249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 17, !dbg !137 + %2250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 18, !dbg !137 + %2251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 19, !dbg !137 + %2252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 20, !dbg !137 + %2253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 21, !dbg !137 + %2254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 22, !dbg !137 + %2255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 23, !dbg !137 + %2256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 24, !dbg !137 + %2257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 25, !dbg !137 + %2258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 26, !dbg !137 + %2259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 27, !dbg !137 + %2260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 28, !dbg !137 + %2261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 29, !dbg !137 + %2262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 30, !dbg !137 + %2263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 31, !dbg !137 + %2264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 32, !dbg !137 + %2265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 33, !dbg !137 + %2266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 34, !dbg !137 + %2267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 35, !dbg !137 + %2268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 36, !dbg !137 + %2269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 37, !dbg !137 + %2270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 38, !dbg !137 + %2271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 39, !dbg !137 + %2272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 40, !dbg !137 + %2273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 41, !dbg !137 + %2274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 42, !dbg !137 + %2275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 43, !dbg !137 + %2276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 44, !dbg !137 + %2277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 45, !dbg !137 + %2278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 46, !dbg !137 + %2279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 47, !dbg !137 + %2280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 48, !dbg !137 + %2281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 49, !dbg !137 + %2282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 50, !dbg !137 + %2283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 51, !dbg !137 + %2284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 52, !dbg !137 + %2285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 53, !dbg !137 + %2286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 54, !dbg !137 + %2287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 55, !dbg !137 + %2288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 56, !dbg !137 + %2289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 57, !dbg !137 + %2290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 58, !dbg !137 + %2291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 59, !dbg !137 + %2292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 60, !dbg !137 + %2293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 61, !dbg !137 + %2294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 62, !dbg !137 + %2295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2226, 63, !dbg !137 + %2296 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2232, float %2233, float %2234, float %2235, float %2236, float %2237, float %2238, float %2239, float %2240, float %2241, float %2242, float %2243, float %2244, float %2245, float %2246, float %2247, float %2248, float %2249, float %2250, float %2251, float %2252, float %2253, float %2254, float %2255, float %2256, float %2257, float %2258, float %2259, float %2260, float %2261, float %2262, float %2263, float %2264, float %2265, float %2266, float %2267, float %2268, float %2269, float %2270, float %2271, float %2272, float %2273, float %2274, float %2275, float %2276, float %2277, float %2278, float %2279, float %2280, float %2281, float %2282, float %2283, float %2284, float %2285, float %2286, float %2287, float %2288, float %2289, float %2290, float %2291, float %2292, float %2293, float %2294, float %2295, i32 %2082, i32 %2083, i32 %2084, i32 %2085, i64 %2231, i1 true) #3, !dbg !137 + %2297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 0, !dbg !137 + %2298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 1, !dbg !137 + %2299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 2, !dbg !137 + %2300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 3, !dbg !137 + %2301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 4, !dbg !137 + %2302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 5, !dbg !137 + %2303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 6, !dbg !137 + %2304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 7, !dbg !137 + %2305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 8, !dbg !137 + %2306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 9, !dbg !137 + %2307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 10, !dbg !137 + %2308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 11, !dbg !137 + %2309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 12, !dbg !137 + %2310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 13, !dbg !137 + %2311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 14, !dbg !137 + %2312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 15, !dbg !137 + %2313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 16, !dbg !137 + %2314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 17, !dbg !137 + %2315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 18, !dbg !137 + %2316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 19, !dbg !137 + %2317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 20, !dbg !137 + %2318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 21, !dbg !137 + %2319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 22, !dbg !137 + %2320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 23, !dbg !137 + %2321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 24, !dbg !137 + %2322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 25, !dbg !137 + %2323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 26, !dbg !137 + %2324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 27, !dbg !137 + %2325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 28, !dbg !137 + %2326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 29, !dbg !137 + %2327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 30, !dbg !137 + %2328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 31, !dbg !137 + %2329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 32, !dbg !137 + %2330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 33, !dbg !137 + %2331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 34, !dbg !137 + %2332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 35, !dbg !137 + %2333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 36, !dbg !137 + %2334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 37, !dbg !137 + %2335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 38, !dbg !137 + %2336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 39, !dbg !137 + %2337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 40, !dbg !137 + %2338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 41, !dbg !137 + %2339 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 42, !dbg !137 + %2340 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 43, !dbg !137 + %2341 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 44, !dbg !137 + %2342 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 45, !dbg !137 + %2343 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 46, !dbg !137 + %2344 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 47, !dbg !137 + %2345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 48, !dbg !137 + %2346 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 49, !dbg !137 + %2347 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 50, !dbg !137 + %2348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 51, !dbg !137 + %2349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 52, !dbg !137 + %2350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 53, !dbg !137 + %2351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 54, !dbg !137 + %2352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 55, !dbg !137 + %2353 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 56, !dbg !137 + %2354 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 57, !dbg !137 + %2355 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 58, !dbg !137 + %2356 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 59, !dbg !137 + %2357 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 60, !dbg !137 + %2358 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 61, !dbg !137 + %2359 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 62, !dbg !137 + %2360 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2296, 63, !dbg !137 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !137 + %2361 = insertelement <2 x i32> poison, i32 %542, i64 0, !dbg !107 + %2362 = shufflevector <2 x i32> %2361, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !107 + %2363 = add <2 x i32> %2362, %617, !dbg !107 + %2364 = add <2 x i32> %2362, %616, !dbg !107 + %2365 = add <2 x i32> %2362, %615, !dbg !107 + %2366 = add <2 x i32> %2362, %614, !dbg !107 + %2367 = add <2 x i32> %2362, %613, !dbg !107 + %2368 = add <2 x i32> %2362, %612, !dbg !107 + %2369 = add <2 x i32> %2362, %611, !dbg !107 + %2370 = add <2 x i32> %2362, %610, !dbg !107 + %2371 = add nuw nsw i32 %609, 1, !dbg !102 + %2372 = lshr i32 %2371, 1, !dbg !138 + %2373 = zext nneg i32 %2372 to i64, !dbg !139 + %2374 = getelementptr i32, ptr addrspace(1) %376, i64 %2373, !dbg !139 + %2375 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !140 + %2376 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2374, i64 %2375, i1 %619) #3, !dbg !140 + %2377 = add nuw nsw i32 %2372, 1, !dbg !141 + %2378 = icmp slt i32 %2377, %381, !dbg !142 + %2379 = getelementptr i8, ptr addrspace(1) %2374, i64 4, !dbg !143 + %2380 = and i1 %619, %2378, !dbg !102 + %2381 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !144 + %2382 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2379, i64 %2381, i1 %2380) #3, !dbg !144 + %2383 = and i32 %609, 1, !dbg !145 + %2384 = sub i32 %2382, %2376, !dbg !146 + %2385 = shl i32 %2384, 7, !dbg !147 + %2386 = add i32 %2385, -64, !dbg !148 + %2387 = xor i32 %2383, 1, !dbg !149 + %2388 = mul nuw nsw i32 %2386, %2387, !dbg !149 + %2389 = shl nuw nsw i32 %2383, 6, !dbg !150 + %2390 = add i32 %2388, %2389, !dbg !151 + %2391 = shl i32 %2390, 10, !dbg !152 + %2392 = sext i32 %2391 to i64, !dbg !105 + %2393 = getelementptr bfloat, ptr addrspace(1) %.pn9091615, i64 %2392, !dbg !105 + %2394 = getelementptr bfloat, ptr addrspace(1) %.pn8931616, i64 %2392, !dbg !105 + %2395 = getelementptr bfloat, ptr addrspace(1) %.pn8771617, i64 %2392, !dbg !105 + %2396 = getelementptr bfloat, ptr addrspace(1) %.pn8611618, i64 %2392, !dbg !105 + %2397 = getelementptr bfloat, ptr addrspace(1) %.pn9811623, i64 %2392, !dbg !106 + %2398 = getelementptr bfloat, ptr addrspace(1) %.pn9651624, i64 %2392, !dbg !106 + %2399 = getelementptr bfloat, ptr addrspace(1) %.pn9491625, i64 %2392, !dbg !106 + %2400 = getelementptr bfloat, ptr addrspace(1) %.pn9331626, i64 %2392, !dbg !106 + %2401 = add i32 %2390, %.pn9171619, !dbg !107 + %2402 = add i32 %2390, %.pn9151620, !dbg !107 + %2403 = add i32 %2390, %.pn9131621, !dbg !107 + %2404 = add i32 %2390, %.pn9111622, !dbg !107 + %2405 = add i32 %544, 1, !dbg !102 + %2406 = icmp sgt i32 %2405, 2, !dbg !102 + %2407 = select i1 %2406, i32 0, i32 %2405, !dbg !102 + %2408 = icmp slt i32 %2401, %18, !dbg !103 + %2409 = icmp slt i32 %2402, %18, !dbg !103 + %2410 = icmp slt i32 %2403, %18, !dbg !103 + %2411 = icmp slt i32 %2404, %18, !dbg !103 + %2412 = shl i32 %2407, 13, !dbg !104 + %2413 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2412, !dbg !104 + %2414 = and i1 %618, %2408, !dbg !102 + %2415 = and i1 %618, %2409, !dbg !102 + %2416 = and i1 %618, %2410, !dbg !102 + %2417 = and i1 %618, %2411, !dbg !102 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !104 + %2418 = getelementptr inbounds nuw i8, ptr addrspace(3) %2413, i32 %437, !dbg !104 + %2419 = select i1 %2414, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2418, ptr addrspace(1) %2393, i32 %2419) #3, !dbg !104 + %2420 = getelementptr inbounds nuw i8, ptr addrspace(3) %2413, i32 %440, !dbg !104 + %2421 = select i1 %2415, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2420, ptr addrspace(1) %2394, i32 %2421) #3, !dbg !104 + %2422 = getelementptr inbounds nuw i8, ptr addrspace(3) %2413, i32 %443, !dbg !104 + %2423 = select i1 %2416, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2422, ptr addrspace(1) %2395, i32 %2423) #3, !dbg !104 + %2424 = getelementptr inbounds nuw i8, ptr addrspace(3) %2413, i32 %446, !dbg !104 + %2425 = select i1 %2417, i32 16, i32 0, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2424, ptr addrspace(1) %2396, i32 %2425) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + %2426 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2412, !dbg !104 + %2427 = getelementptr inbounds nuw i8, ptr addrspace(3) %2426, i32 %437, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2427, ptr addrspace(1) %2397, i32 %2419) #3, !dbg !104 + %2428 = getelementptr inbounds nuw i8, ptr addrspace(3) %2426, i32 %440, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2428, ptr addrspace(1) %2398, i32 %2421) #3, !dbg !104 + %2429 = getelementptr inbounds nuw i8, ptr addrspace(3) %2426, i32 %443, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2429, ptr addrspace(1) %2399, i32 %2423) #3, !dbg !104 + %2430 = getelementptr inbounds nuw i8, ptr addrspace(3) %2426, i32 %446, !dbg !104 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2430, ptr addrspace(1) %2400, i32 %2425) #3, !dbg !104 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !104 + %exitcond.not = icmp eq i32 %2371, %smax, !dbg !102 + br i1 %exitcond.not, label %._crit_edge, label %541, !dbg !102 + +._crit_edge: ; preds = %__nv_exp2f.exit1609, %67 + %2431 = phi float [ 0.000000e+00, %67 ], [ %2297, %__nv_exp2f.exit1609 ] + %2432 = phi float [ 0.000000e+00, %67 ], [ %2298, %__nv_exp2f.exit1609 ] + %2433 = phi float [ 0.000000e+00, %67 ], [ %2299, %__nv_exp2f.exit1609 ] + %2434 = phi float [ 0.000000e+00, %67 ], [ %2300, %__nv_exp2f.exit1609 ] + %2435 = phi float [ 0.000000e+00, %67 ], [ %2301, %__nv_exp2f.exit1609 ] + %2436 = phi float [ 0.000000e+00, %67 ], [ %2302, %__nv_exp2f.exit1609 ] + %2437 = phi float [ 0.000000e+00, %67 ], [ %2303, %__nv_exp2f.exit1609 ] + %2438 = phi float [ 0.000000e+00, %67 ], [ %2304, %__nv_exp2f.exit1609 ] + %2439 = phi float [ 0.000000e+00, %67 ], [ %2305, %__nv_exp2f.exit1609 ] + %2440 = phi float [ 0.000000e+00, %67 ], [ %2306, %__nv_exp2f.exit1609 ] + %2441 = phi float [ 0.000000e+00, %67 ], [ %2307, %__nv_exp2f.exit1609 ] + %2442 = phi float [ 0.000000e+00, %67 ], [ %2308, %__nv_exp2f.exit1609 ] + %2443 = phi float [ 0.000000e+00, %67 ], [ %2309, %__nv_exp2f.exit1609 ] + %2444 = phi float [ 0.000000e+00, %67 ], [ %2310, %__nv_exp2f.exit1609 ] + %2445 = phi float [ 0.000000e+00, %67 ], [ %2311, %__nv_exp2f.exit1609 ] + %2446 = phi float [ 0.000000e+00, %67 ], [ %2312, %__nv_exp2f.exit1609 ] + %2447 = phi float [ 0.000000e+00, %67 ], [ %2313, %__nv_exp2f.exit1609 ] + %2448 = phi float [ 0.000000e+00, %67 ], [ %2314, %__nv_exp2f.exit1609 ] + %2449 = phi float [ 0.000000e+00, %67 ], [ %2315, %__nv_exp2f.exit1609 ] + %2450 = phi float [ 0.000000e+00, %67 ], [ %2316, %__nv_exp2f.exit1609 ] + %2451 = phi float [ 0.000000e+00, %67 ], [ %2317, %__nv_exp2f.exit1609 ] + %2452 = phi float [ 0.000000e+00, %67 ], [ %2318, %__nv_exp2f.exit1609 ] + %2453 = phi float [ 0.000000e+00, %67 ], [ %2319, %__nv_exp2f.exit1609 ] + %2454 = phi float [ 0.000000e+00, %67 ], [ %2320, %__nv_exp2f.exit1609 ] + %2455 = phi float [ 0.000000e+00, %67 ], [ %2321, %__nv_exp2f.exit1609 ] + %2456 = phi float [ 0.000000e+00, %67 ], [ %2322, %__nv_exp2f.exit1609 ] + %2457 = phi float [ 0.000000e+00, %67 ], [ %2323, %__nv_exp2f.exit1609 ] + %2458 = phi float [ 0.000000e+00, %67 ], [ %2324, %__nv_exp2f.exit1609 ] + %2459 = phi float [ 0.000000e+00, %67 ], [ %2325, %__nv_exp2f.exit1609 ] + %2460 = phi float [ 0.000000e+00, %67 ], [ %2326, %__nv_exp2f.exit1609 ] + %2461 = phi float [ 0.000000e+00, %67 ], [ %2327, %__nv_exp2f.exit1609 ] + %2462 = phi float [ 0.000000e+00, %67 ], [ %2328, %__nv_exp2f.exit1609 ] + %2463 = phi float [ 0.000000e+00, %67 ], [ %2329, %__nv_exp2f.exit1609 ] + %2464 = phi float [ 0.000000e+00, %67 ], [ %2330, %__nv_exp2f.exit1609 ] + %2465 = phi float [ 0.000000e+00, %67 ], [ %2331, %__nv_exp2f.exit1609 ] + %2466 = phi float [ 0.000000e+00, %67 ], [ %2332, %__nv_exp2f.exit1609 ] + %2467 = phi float [ 0.000000e+00, %67 ], [ %2333, %__nv_exp2f.exit1609 ] + %2468 = phi float [ 0.000000e+00, %67 ], [ %2334, %__nv_exp2f.exit1609 ] + %2469 = phi float [ 0.000000e+00, %67 ], [ %2335, %__nv_exp2f.exit1609 ] + %2470 = phi float [ 0.000000e+00, %67 ], [ %2336, %__nv_exp2f.exit1609 ] + %2471 = phi float [ 0.000000e+00, %67 ], [ %2337, %__nv_exp2f.exit1609 ] + %2472 = phi float [ 0.000000e+00, %67 ], [ %2338, %__nv_exp2f.exit1609 ] + %2473 = phi float [ 0.000000e+00, %67 ], [ %2339, %__nv_exp2f.exit1609 ] + %2474 = phi float [ 0.000000e+00, %67 ], [ %2340, %__nv_exp2f.exit1609 ] + %2475 = phi float [ 0.000000e+00, %67 ], [ %2341, %__nv_exp2f.exit1609 ] + %2476 = phi float [ 0.000000e+00, %67 ], [ %2342, %__nv_exp2f.exit1609 ] + %2477 = phi float [ 0.000000e+00, %67 ], [ %2343, %__nv_exp2f.exit1609 ] + %2478 = phi float [ 0.000000e+00, %67 ], [ %2344, %__nv_exp2f.exit1609 ] + %2479 = phi float [ 0.000000e+00, %67 ], [ %2345, %__nv_exp2f.exit1609 ] + %2480 = phi float [ 0.000000e+00, %67 ], [ %2346, %__nv_exp2f.exit1609 ] + %2481 = phi float [ 0.000000e+00, %67 ], [ %2347, %__nv_exp2f.exit1609 ] + %2482 = phi float [ 0.000000e+00, %67 ], [ %2348, %__nv_exp2f.exit1609 ] + %2483 = phi float [ 0.000000e+00, %67 ], [ %2349, %__nv_exp2f.exit1609 ] + %2484 = phi float [ 0.000000e+00, %67 ], [ %2350, %__nv_exp2f.exit1609 ] + %2485 = phi float [ 0.000000e+00, %67 ], [ %2351, %__nv_exp2f.exit1609 ] + %2486 = phi float [ 0.000000e+00, %67 ], [ %2352, %__nv_exp2f.exit1609 ] + %2487 = phi float [ 0.000000e+00, %67 ], [ %2353, %__nv_exp2f.exit1609 ] + %2488 = phi float [ 0.000000e+00, %67 ], [ %2354, %__nv_exp2f.exit1609 ] + %2489 = phi float [ 0.000000e+00, %67 ], [ %2355, %__nv_exp2f.exit1609 ] + %2490 = phi float [ 0.000000e+00, %67 ], [ %2356, %__nv_exp2f.exit1609 ] + %2491 = phi float [ 0.000000e+00, %67 ], [ %2357, %__nv_exp2f.exit1609 ] + %2492 = phi float [ 0.000000e+00, %67 ], [ %2358, %__nv_exp2f.exit1609 ] + %2493 = phi float [ 0.000000e+00, %67 ], [ %2359, %__nv_exp2f.exit1609 ] + %2494 = phi float [ 0.000000e+00, %67 ], [ %2360, %__nv_exp2f.exit1609 ] + %2495 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %2431, float %2432, float %2433, float %2434, float %2435, float %2436, float %2437, float %2438, float %2439, float %2440, float %2441, float %2442, float %2443, float %2444, float %2445, float %2446, float %2447, float %2448, float %2449, float %2450, float %2451, float %2452, float %2453, float %2454, float %2455, float %2456, float %2457, float %2458, float %2459, float %2460, float %2461, float %2462, float %2463, float %2464, float %2465, float %2466, float %2467, float %2468, float %2469, float %2470, float %2471, float %2472, float %2473, float %2474, float %2475, float %2476, float %2477, float %2478, float %2479, float %2480, float %2481, float %2482, float %2483, float %2484, float %2485, float %2486, float %2487, float %2488, float %2489, float %2490, float %2491, float %2492, float %2493, float %2494) #3, !dbg !102 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !102 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !102 + %2496 = getelementptr i32, ptr addrspace(1) %13, i64 %375, !dbg !153 + %2497 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2496) #3, !dbg !154 + %2498 = shl i32 %2497, 7, !dbg !155 + %2499 = getelementptr i32, ptr addrspace(1) %12, i64 %379, !dbg !156 + %2500 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2499) #3, !dbg !157 + %2501 = or disjoint i32 %2498, %53, !dbg !158 + %2502 = or disjoint i32 %2498, %54, !dbg !158 + %2503 = or disjoint i32 %2498, %55, !dbg !158 + %2504 = or disjoint i32 %2498, %56, !dbg !158 + %2505 = shl i32 %2501, 10, !dbg !159 + %2506 = shl i32 %2502, 10, !dbg !159 + %2507 = shl i32 %2503, 10, !dbg !159 + %2508 = shl i32 %2504, 10, !dbg !159 + %2509 = sext i32 %2505 to i64, !dbg !161 + %2510 = getelementptr bfloat, ptr addrspace(1) %47, i64 %2509, !dbg !161 + %2511 = sext i32 %2506 to i64, !dbg !161 + %2512 = getelementptr bfloat, ptr addrspace(1) %47, i64 %2511, !dbg !161 + %2513 = sext i32 %2507 to i64, !dbg !161 + %2514 = getelementptr bfloat, ptr addrspace(1) %47, i64 %2513, !dbg !161 + %2515 = sext i32 %2508 to i64, !dbg !161 + %2516 = getelementptr bfloat, ptr addrspace(1) %47, i64 %2515, !dbg !161 + %2517 = getelementptr bfloat, ptr addrspace(1) %2510, i64 %130, !dbg !162 + %2518 = getelementptr bfloat, ptr addrspace(1) %2512, i64 %130, !dbg !162 + %2519 = getelementptr bfloat, ptr addrspace(1) %2514, i64 %130, !dbg !162 + %2520 = getelementptr bfloat, ptr addrspace(1) %2516, i64 %130, !dbg !162 + %2521 = getelementptr bfloat, ptr addrspace(1) %48, i64 %2509, !dbg !163 + %2522 = getelementptr bfloat, ptr addrspace(1) %48, i64 %2511, !dbg !163 + %2523 = getelementptr bfloat, ptr addrspace(1) %48, i64 %2513, !dbg !163 + %2524 = getelementptr bfloat, ptr addrspace(1) %48, i64 %2515, !dbg !163 + %2525 = getelementptr bfloat, ptr addrspace(1) %2521, i64 %130, !dbg !164 + %2526 = getelementptr bfloat, ptr addrspace(1) %2522, i64 %130, !dbg !164 + %2527 = getelementptr bfloat, ptr addrspace(1) %2523, i64 %130, !dbg !164 + %2528 = getelementptr bfloat, ptr addrspace(1) %2524, i64 %130, !dbg !164 + %2529 = shl i32 %2500, 1, !dbg !165 + %2530 = tail call i32 @llvm.smin.i32(i32 %2529, i32 %425), !dbg !166 + %2531 = icmp sgt i32 %2529, 0, !dbg !167 + %2532 = icmp slt i32 %2501, %18, !dbg !168 + %2533 = icmp slt i32 %2502, %18, !dbg !168 + %2534 = icmp slt i32 %2503, %18, !dbg !168 + %2535 = icmp slt i32 %2504, %18, !dbg !168 + %2536 = and i1 %2531, %2532, !dbg !167 + %2537 = and i1 %2531, %2533, !dbg !167 + %2538 = and i1 %2531, %2534, !dbg !167 + %2539 = and i1 %2531, %2535, !dbg !167 + %2540 = select i1 %2536, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %438, ptr addrspace(1) %2517, i32 %2540) #3, !dbg !169 + %2541 = select i1 %2537, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %441, ptr addrspace(1) %2518, i32 %2541) #3, !dbg !169 + %2542 = select i1 %2538, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %444, ptr addrspace(1) %2519, i32 %2542) #3, !dbg !169 + %2543 = select i1 %2539, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %447, ptr addrspace(1) %2520, i32 %2543) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %449, ptr addrspace(1) %2525, i32 %2540) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %450, ptr addrspace(1) %2526, i32 %2541) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %451, ptr addrspace(1) %2527, i32 %2542) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %452, ptr addrspace(1) %2528, i32 %2543) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + %2544 = icmp sgt i32 %2530, 1, !dbg !167 + %2545 = getelementptr i8, ptr addrspace(1) %2517, i64 131072, !dbg !170 + %2546 = getelementptr i8, ptr addrspace(1) %2518, i64 131072, !dbg !170 + %2547 = getelementptr i8, ptr addrspace(1) %2519, i64 131072, !dbg !170 + %2548 = getelementptr i8, ptr addrspace(1) %2520, i64 131072, !dbg !170 + %2549 = getelementptr i8, ptr addrspace(1) %2525, i64 131072, !dbg !171 + %2550 = getelementptr i8, ptr addrspace(1) %2526, i64 131072, !dbg !171 + %2551 = getelementptr i8, ptr addrspace(1) %2527, i64 131072, !dbg !171 + %2552 = getelementptr i8, ptr addrspace(1) %2528, i64 131072, !dbg !171 + %2553 = or disjoint i32 %2501, 64, !dbg !172 + %2554 = or disjoint i32 %2502, 64, !dbg !172 + %2555 = or disjoint i32 %2503, 64, !dbg !172 + %2556 = or disjoint i32 %2504, 64, !dbg !172 + %2557 = icmp slt i32 %2553, %18, !dbg !168 + %2558 = icmp slt i32 %2554, %18, !dbg !168 + %2559 = icmp slt i32 %2555, %18, !dbg !168 + %2560 = icmp slt i32 %2556, %18, !dbg !168 + %2561 = and i1 %2544, %2557, !dbg !167 + %2562 = and i1 %2544, %2558, !dbg !167 + %2563 = and i1 %2544, %2559, !dbg !167 + %2564 = and i1 %2544, %2560, !dbg !167 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !169 + %2565 = select i1 %2561, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %474, ptr addrspace(1) %2545, i32 %2565) #3, !dbg !169 + %2566 = select i1 %2562, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %476, ptr addrspace(1) %2546, i32 %2566) #3, !dbg !169 + %2567 = select i1 %2563, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %478, ptr addrspace(1) %2547, i32 %2567) #3, !dbg !169 + %2568 = select i1 %2564, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %480, ptr addrspace(1) %2548, i32 %2568) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %482, ptr addrspace(1) %2549, i32 %2565) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %483, ptr addrspace(1) %2550, i32 %2566) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %484, ptr addrspace(1) %2551, i32 %2567) #3, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %485, ptr addrspace(1) %2552, i32 %2568) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !173 + %2569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 0, !dbg !167 + %2570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 1, !dbg !167 + %2571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 2, !dbg !167 + %2572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 3, !dbg !167 + %2573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 4, !dbg !167 + %2574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 5, !dbg !167 + %2575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 6, !dbg !167 + %2576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 7, !dbg !167 + %2577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 8, !dbg !167 + %2578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 9, !dbg !167 + %2579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 10, !dbg !167 + %2580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 11, !dbg !167 + %2581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 12, !dbg !167 + %2582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 13, !dbg !167 + %2583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 14, !dbg !167 + %2584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 15, !dbg !167 + %2585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 16, !dbg !167 + %2586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 17, !dbg !167 + %2587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 18, !dbg !167 + %2588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 19, !dbg !167 + %2589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 20, !dbg !167 + %2590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 21, !dbg !167 + %2591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 22, !dbg !167 + %2592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 23, !dbg !167 + %2593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 24, !dbg !167 + %2594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 25, !dbg !167 + %2595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 26, !dbg !167 + %2596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 27, !dbg !167 + %2597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 28, !dbg !167 + %2598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 29, !dbg !167 + %2599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 30, !dbg !167 + %2600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 31, !dbg !167 + %2601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 32, !dbg !167 + %2602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 33, !dbg !167 + %2603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 34, !dbg !167 + %2604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 35, !dbg !167 + %2605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 36, !dbg !167 + %2606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 37, !dbg !167 + %2607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 38, !dbg !167 + %2608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 39, !dbg !167 + %2609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 40, !dbg !167 + %2610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 41, !dbg !167 + %2611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 42, !dbg !167 + %2612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 43, !dbg !167 + %2613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 44, !dbg !167 + %2614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 45, !dbg !167 + %2615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 46, !dbg !167 + %2616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 47, !dbg !167 + %2617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 48, !dbg !167 + %2618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 49, !dbg !167 + %2619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 50, !dbg !167 + %2620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 51, !dbg !167 + %2621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 52, !dbg !167 + %2622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 53, !dbg !167 + %2623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 54, !dbg !167 + %2624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 55, !dbg !167 + %2625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 56, !dbg !167 + %2626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 57, !dbg !167 + %2627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 58, !dbg !167 + %2628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 59, !dbg !167 + %2629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 60, !dbg !167 + %2630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 61, !dbg !167 + %2631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 62, !dbg !167 + %2632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 63, !dbg !167 + br i1 %2531, label %.lr.ph1672, label %._crit_edge1673, !dbg !167 + +.lr.ph1672: ; preds = %._crit_edge + %2633 = insertelement <16 x i32> poison, i32 %2498, i64 0, !dbg !158 + %2634 = shufflevector <16 x i32> %2633, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !158 + %2635 = shufflevector <2 x i32> %387, <2 x i32> poison, <16 x i32> , !dbg !158 + %2636 = insertelement <16 x i32> %2635, i32 %384, i64 14, !dbg !158 + %2637 = insertelement <16 x i32> %2636, i32 %383, i64 15, !dbg !158 + %2638 = shufflevector <8 x i32> %393, <8 x i32> poison, <16 x i32> , !dbg !158 + %2639 = shufflevector <16 x i32> %2638, <16 x i32> %2637, <16 x i32> , !dbg !158 + %2640 = shufflevector <4 x i32> %390, <4 x i32> poison, <16 x i32> , !dbg !158 + %2641 = shufflevector <16 x i32> %2639, <16 x i32> %2640, <16 x i32> , !dbg !158 + %2642 = or disjoint <16 x i32> %2634, %2641, !dbg !158 + %2643 = add nsw i32 %2530, -2 + %2644 = add nsw i32 %2530, -1 + %smax2263 = tail call i32 @llvm.smax.i32(i32 %2530, i32 1), !dbg !167 + %2645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 0, !dbg !167 + %2646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 1, !dbg !167 + %2647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 2, !dbg !167 + %2648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 3, !dbg !167 + %2649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 4, !dbg !167 + %2650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 5, !dbg !167 + %2651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 6, !dbg !167 + %2652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 7, !dbg !167 + %2653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 8, !dbg !167 + %2654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 9, !dbg !167 + %2655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 10, !dbg !167 + %2656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 11, !dbg !167 + %2657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 12, !dbg !167 + %2658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 13, !dbg !167 + %2659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 14, !dbg !167 + %2660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 15, !dbg !167 + %2661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 16, !dbg !167 + %2662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 17, !dbg !167 + %2663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 18, !dbg !167 + %2664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 19, !dbg !167 + %2665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 20, !dbg !167 + %2666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 21, !dbg !167 + %2667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 22, !dbg !167 + %2668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 23, !dbg !167 + %2669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 24, !dbg !167 + %2670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 25, !dbg !167 + %2671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 26, !dbg !167 + %2672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 27, !dbg !167 + %2673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 28, !dbg !167 + %2674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 29, !dbg !167 + %2675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 30, !dbg !167 + %2676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 31, !dbg !167 + %2677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 32, !dbg !167 + %2678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 33, !dbg !167 + %2679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 34, !dbg !167 + %2680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 35, !dbg !167 + %2681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 36, !dbg !167 + %2682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 37, !dbg !167 + %2683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 38, !dbg !167 + %2684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 39, !dbg !167 + %2685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 40, !dbg !167 + %2686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 41, !dbg !167 + %2687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 42, !dbg !167 + %2688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 43, !dbg !167 + %2689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 44, !dbg !167 + %2690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 45, !dbg !167 + %2691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 46, !dbg !167 + %2692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 47, !dbg !167 + %2693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 48, !dbg !167 + %2694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 49, !dbg !167 + %2695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 50, !dbg !167 + %2696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 51, !dbg !167 + %2697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 52, !dbg !167 + %2698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 53, !dbg !167 + %2699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 54, !dbg !167 + %2700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 55, !dbg !167 + %2701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 56, !dbg !167 + %2702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 57, !dbg !167 + %2703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 58, !dbg !167 + %2704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 59, !dbg !167 + %2705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 60, !dbg !167 + %2706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 61, !dbg !167 + %2707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 62, !dbg !167 + %2708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2495, 63, !dbg !167 + br label %2709, !dbg !167 + +2709: ; preds = %.lr.ph1672, %__nv_exp2f.exit1513 + %2710 = phi i32 [ 64, %.lr.ph1672 ], [ %4280, %__nv_exp2f.exit1513 ] + %2711 = phi i32 [ -1, %.lr.ph1672 ], [ %2719, %__nv_exp2f.exit1513 ] + %2712 = phi i32 [ 1, %.lr.ph1672 ], [ %4297, %__nv_exp2f.exit1513 ] + %.pn11011654 = phi ptr addrspace(1) [ %2552, %.lr.ph1672 ], [ %4290, %__nv_exp2f.exit1513 ] + %.pn11171653 = phi ptr addrspace(1) [ %2551, %.lr.ph1672 ], [ %4289, %__nv_exp2f.exit1513 ] + %.pn11331652 = phi ptr addrspace(1) [ %2550, %.lr.ph1672 ], [ %4288, %__nv_exp2f.exit1513 ] + %.pn11491651 = phi ptr addrspace(1) [ %2549, %.lr.ph1672 ], [ %4287, %__nv_exp2f.exit1513 ] + %.pn10791650 = phi i32 [ %2556, %.lr.ph1672 ], [ %4294, %__nv_exp2f.exit1513 ] + %.pn10811649 = phi i32 [ %2555, %.lr.ph1672 ], [ %4293, %__nv_exp2f.exit1513 ] + %.pn10831648 = phi i32 [ %2554, %.lr.ph1672 ], [ %4292, %__nv_exp2f.exit1513 ] + %.pn10851647 = phi i32 [ %2553, %.lr.ph1672 ], [ %4291, %__nv_exp2f.exit1513 ] + %.pn10291646 = phi ptr addrspace(1) [ %2548, %.lr.ph1672 ], [ %4286, %__nv_exp2f.exit1513 ] + %.pn10451645 = phi ptr addrspace(1) [ %2547, %.lr.ph1672 ], [ %4285, %__nv_exp2f.exit1513 ] + %.pn10611644 = phi ptr addrspace(1) [ %2546, %.lr.ph1672 ], [ %4284, %__nv_exp2f.exit1513 ] + %.pn10771643 = phi ptr addrspace(1) [ %2545, %.lr.ph1672 ], [ %4283, %__nv_exp2f.exit1513 ] + %.pn = phi float [ %2645, %.lr.ph1672 ], [ %4194, %__nv_exp2f.exit1513 ] + %.pn2530 = phi float [ %2646, %.lr.ph1672 ], [ %4195, %__nv_exp2f.exit1513 ] + %.pn2531 = phi float [ %2647, %.lr.ph1672 ], [ %4196, %__nv_exp2f.exit1513 ] + %.pn2532 = phi float [ %2648, %.lr.ph1672 ], [ %4197, %__nv_exp2f.exit1513 ] + %.pn2533 = phi float [ %2649, %.lr.ph1672 ], [ %4198, %__nv_exp2f.exit1513 ] + %.pn2534 = phi float [ %2650, %.lr.ph1672 ], [ %4199, %__nv_exp2f.exit1513 ] + %.pn2535 = phi float [ %2651, %.lr.ph1672 ], [ %4200, %__nv_exp2f.exit1513 ] + %.pn2536 = phi float [ %2652, %.lr.ph1672 ], [ %4201, %__nv_exp2f.exit1513 ] + %.pn2537 = phi float [ %2653, %.lr.ph1672 ], [ %4202, %__nv_exp2f.exit1513 ] + %.pn2538 = phi float [ %2654, %.lr.ph1672 ], [ %4203, %__nv_exp2f.exit1513 ] + %.pn2539 = phi float [ %2655, %.lr.ph1672 ], [ %4204, %__nv_exp2f.exit1513 ] + %.pn2540 = phi float [ %2656, %.lr.ph1672 ], [ %4205, %__nv_exp2f.exit1513 ] + %.pn2541 = phi float [ %2657, %.lr.ph1672 ], [ %4206, %__nv_exp2f.exit1513 ] + %.pn2542 = phi float [ %2658, %.lr.ph1672 ], [ %4207, %__nv_exp2f.exit1513 ] + %.pn2543 = phi float [ %2659, %.lr.ph1672 ], [ %4208, %__nv_exp2f.exit1513 ] + %.pn2544 = phi float [ %2660, %.lr.ph1672 ], [ %4209, %__nv_exp2f.exit1513 ] + %.pn2545 = phi float [ %2661, %.lr.ph1672 ], [ %4210, %__nv_exp2f.exit1513 ] + %.pn2546 = phi float [ %2662, %.lr.ph1672 ], [ %4211, %__nv_exp2f.exit1513 ] + %.pn2547 = phi float [ %2663, %.lr.ph1672 ], [ %4212, %__nv_exp2f.exit1513 ] + %.pn2548 = phi float [ %2664, %.lr.ph1672 ], [ %4213, %__nv_exp2f.exit1513 ] + %.pn2549 = phi float [ %2665, %.lr.ph1672 ], [ %4214, %__nv_exp2f.exit1513 ] + %.pn2550 = phi float [ %2666, %.lr.ph1672 ], [ %4215, %__nv_exp2f.exit1513 ] + %.pn2551 = phi float [ %2667, %.lr.ph1672 ], [ %4216, %__nv_exp2f.exit1513 ] + %.pn2552 = phi float [ %2668, %.lr.ph1672 ], [ %4217, %__nv_exp2f.exit1513 ] + %.pn2553 = phi float [ %2669, %.lr.ph1672 ], [ %4218, %__nv_exp2f.exit1513 ] + %.pn2554 = phi float [ %2670, %.lr.ph1672 ], [ %4219, %__nv_exp2f.exit1513 ] + %.pn2555 = phi float [ %2671, %.lr.ph1672 ], [ %4220, %__nv_exp2f.exit1513 ] + %.pn2556 = phi float [ %2672, %.lr.ph1672 ], [ %4221, %__nv_exp2f.exit1513 ] + %.pn2557 = phi float [ %2673, %.lr.ph1672 ], [ %4222, %__nv_exp2f.exit1513 ] + %.pn2558 = phi float [ %2674, %.lr.ph1672 ], [ %4223, %__nv_exp2f.exit1513 ] + %.pn2559 = phi float [ %2675, %.lr.ph1672 ], [ %4224, %__nv_exp2f.exit1513 ] + %.pn2560 = phi float [ %2676, %.lr.ph1672 ], [ %4225, %__nv_exp2f.exit1513 ] + %.pn2561 = phi float [ %2677, %.lr.ph1672 ], [ %4226, %__nv_exp2f.exit1513 ] + %.pn2562 = phi float [ %2678, %.lr.ph1672 ], [ %4227, %__nv_exp2f.exit1513 ] + %.pn2563 = phi float [ %2679, %.lr.ph1672 ], [ %4228, %__nv_exp2f.exit1513 ] + %.pn2564 = phi float [ %2680, %.lr.ph1672 ], [ %4229, %__nv_exp2f.exit1513 ] + %.pn2565 = phi float [ %2681, %.lr.ph1672 ], [ %4230, %__nv_exp2f.exit1513 ] + %.pn2566 = phi float [ %2682, %.lr.ph1672 ], [ %4231, %__nv_exp2f.exit1513 ] + %.pn2567 = phi float [ %2683, %.lr.ph1672 ], [ %4232, %__nv_exp2f.exit1513 ] + %.pn2568 = phi float [ %2684, %.lr.ph1672 ], [ %4233, %__nv_exp2f.exit1513 ] + %.pn2569 = phi float [ %2685, %.lr.ph1672 ], [ %4234, %__nv_exp2f.exit1513 ] + %.pn2570 = phi float [ %2686, %.lr.ph1672 ], [ %4235, %__nv_exp2f.exit1513 ] + %.pn2571 = phi float [ %2687, %.lr.ph1672 ], [ %4236, %__nv_exp2f.exit1513 ] + %.pn2572 = phi float [ %2688, %.lr.ph1672 ], [ %4237, %__nv_exp2f.exit1513 ] + %.pn2573 = phi float [ %2689, %.lr.ph1672 ], [ %4238, %__nv_exp2f.exit1513 ] + %.pn2574 = phi float [ %2690, %.lr.ph1672 ], [ %4239, %__nv_exp2f.exit1513 ] + %.pn2575 = phi float [ %2691, %.lr.ph1672 ], [ %4240, %__nv_exp2f.exit1513 ] + %.pn2576 = phi float [ %2692, %.lr.ph1672 ], [ %4241, %__nv_exp2f.exit1513 ] + %.pn2577 = phi float [ %2693, %.lr.ph1672 ], [ %4242, %__nv_exp2f.exit1513 ] + %.pn2578 = phi float [ %2694, %.lr.ph1672 ], [ %4243, %__nv_exp2f.exit1513 ] + %.pn2579 = phi float [ %2695, %.lr.ph1672 ], [ %4244, %__nv_exp2f.exit1513 ] + %.pn2580 = phi float [ %2696, %.lr.ph1672 ], [ %4245, %__nv_exp2f.exit1513 ] + %.pn2581 = phi float [ %2697, %.lr.ph1672 ], [ %4246, %__nv_exp2f.exit1513 ] + %.pn2582 = phi float [ %2698, %.lr.ph1672 ], [ %4247, %__nv_exp2f.exit1513 ] + %.pn2583 = phi float [ %2699, %.lr.ph1672 ], [ %4248, %__nv_exp2f.exit1513 ] + %.pn2584 = phi float [ %2700, %.lr.ph1672 ], [ %4249, %__nv_exp2f.exit1513 ] + %.pn2585 = phi float [ %2701, %.lr.ph1672 ], [ %4250, %__nv_exp2f.exit1513 ] + %.pn2586 = phi float [ %2702, %.lr.ph1672 ], [ %4251, %__nv_exp2f.exit1513 ] + %.pn2587 = phi float [ %2703, %.lr.ph1672 ], [ %4252, %__nv_exp2f.exit1513 ] + %.pn2588 = phi float [ %2704, %.lr.ph1672 ], [ %4253, %__nv_exp2f.exit1513 ] + %.pn2589 = phi float [ %2705, %.lr.ph1672 ], [ %4254, %__nv_exp2f.exit1513 ] + %.pn2590 = phi float [ %2706, %.lr.ph1672 ], [ %4255, %__nv_exp2f.exit1513 ] + %.pn2591 = phi float [ %2707, %.lr.ph1672 ], [ %4256, %__nv_exp2f.exit1513 ] + %.pn2592 = phi float [ %2708, %.lr.ph1672 ], [ %4257, %__nv_exp2f.exit1513 ] + %2713 = phi i32 [ 0, %.lr.ph1672 ], [ %4261, %__nv_exp2f.exit1513 ] + %2714 = phi <16 x i32> [ %2642, %.lr.ph1672 ], [ %4260, %__nv_exp2f.exit1513 ] + %2715 = icmp slt i32 %2713, %2643, !dbg !167 + %2716 = icmp slt i32 %2713, %2644, !dbg !167 + %2717 = add i32 %2711, 1, !dbg !167 + %2718 = icmp sgt i32 %2717, 2, !dbg !167 + %2719 = select i1 %2718, i32 0, i32 %2717, !dbg !167 + %2720 = extractelement <16 x i32> %2714, i64 15, !dbg !168 + %2721 = icmp slt i32 %2720, %18, !dbg !168 + %2722 = extractelement <16 x i32> %2714, i64 14, !dbg !168 + %2723 = icmp slt i32 %2722, %18, !dbg !168 + %2724 = extractelement <16 x i32> %2714, i64 13, !dbg !168 + %2725 = icmp slt i32 %2724, %18, !dbg !168 + %2726 = extractelement <16 x i32> %2714, i64 12, !dbg !168 + %2727 = icmp slt i32 %2726, %18, !dbg !168 + %2728 = extractelement <16 x i32> %2714, i64 11, !dbg !168 + %2729 = icmp slt i32 %2728, %18, !dbg !168 + %2730 = extractelement <16 x i32> %2714, i64 10, !dbg !168 + %2731 = icmp slt i32 %2730, %18, !dbg !168 + %2732 = extractelement <16 x i32> %2714, i64 9, !dbg !168 + %2733 = icmp slt i32 %2732, %18, !dbg !168 + %2734 = extractelement <16 x i32> %2714, i64 8, !dbg !168 + %2735 = icmp slt i32 %2734, %18, !dbg !168 + %2736 = extractelement <16 x i32> %2714, i64 7, !dbg !168 + %2737 = icmp slt i32 %2736, %18, !dbg !168 + %2738 = extractelement <16 x i32> %2714, i64 6, !dbg !168 + %2739 = icmp slt i32 %2738, %18, !dbg !168 + %2740 = extractelement <16 x i32> %2714, i64 5, !dbg !168 + %2741 = icmp slt i32 %2740, %18, !dbg !168 + %2742 = extractelement <16 x i32> %2714, i64 4, !dbg !168 + %2743 = icmp slt i32 %2742, %18, !dbg !168 + %2744 = extractelement <16 x i32> %2714, i64 3, !dbg !168 + %2745 = icmp slt i32 %2744, %18, !dbg !168 + %2746 = extractelement <16 x i32> %2714, i64 2, !dbg !168 + %2747 = icmp slt i32 %2746, %18, !dbg !168 + %2748 = extractelement <16 x i32> %2714, i64 1, !dbg !168 + %2749 = icmp slt i32 %2748, %18, !dbg !168 + %2750 = extractelement <16 x i32> %2714, i64 0, !dbg !168 + %2751 = icmp slt i32 %2750, %18, !dbg !168 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !169 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !169 + %2752 = shl i32 %2719, 13, !dbg !169 + %2753 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2752, !dbg !169 + %2754 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %51, i32 0, i32 31), !dbg !173 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !173 + %2755 = shl i32 %2754, 11, !dbg !173 + %2756 = and i32 %2755, 8192, !dbg !173 + %2757 = add i32 %2756, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2758 = lshr exact i32 %2757, 4, !dbg !173 + %2759 = and i32 %2758, 16383, !dbg !173 + %2760 = zext nneg i32 %2759 to i64, !dbg !173 + %2761 = or disjoint i64 %2760, 4611686293372403712, !dbg !173 + %2762 = ptrtoint ptr addrspace(3) %2753 to i32, !dbg !173 + %2763 = lshr exact i32 %2762, 4, !dbg !173 + %2764 = and i32 %2763, 16383, !dbg !173 + %2765 = zext nneg i32 %2764 to i64, !dbg !173 + %2766 = or disjoint i64 %2765, 4611686293338849280, !dbg !173 + %2767 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %2761, i64 %2766) #3, !dbg !173 + %2768 = or disjoint i32 %2756, 32, !dbg !173 + %2769 = add i32 %2768, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2770 = lshr exact i32 %2769, 4, !dbg !173 + %2771 = and i32 %2770, 16383, !dbg !173 + %2772 = zext nneg i32 %2771 to i64, !dbg !173 + %2773 = or disjoint i64 %2772, 4611686293372403712, !dbg !173 + %2774 = add i32 %2762, 32, !dbg !173 + %2775 = lshr exact i32 %2774, 4, !dbg !173 + %2776 = and i32 %2775, 16383, !dbg !173 + %2777 = zext nneg i32 %2776 to i64, !dbg !173 + %2778 = or disjoint i64 %2777, 4611686293338849280, !dbg !173 + %2779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 0, !dbg !173 + %2780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 1, !dbg !173 + %2781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 2, !dbg !173 + %2782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 3, !dbg !173 + %2783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 4, !dbg !173 + %2784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 5, !dbg !173 + %2785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 6, !dbg !173 + %2786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 7, !dbg !173 + %2787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 8, !dbg !173 + %2788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 9, !dbg !173 + %2789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 10, !dbg !173 + %2790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 11, !dbg !173 + %2791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 12, !dbg !173 + %2792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 13, !dbg !173 + %2793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 14, !dbg !173 + %2794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 15, !dbg !173 + %2795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 16, !dbg !173 + %2796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 17, !dbg !173 + %2797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 18, !dbg !173 + %2798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 19, !dbg !173 + %2799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 20, !dbg !173 + %2800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 21, !dbg !173 + %2801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 22, !dbg !173 + %2802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 23, !dbg !173 + %2803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 24, !dbg !173 + %2804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 25, !dbg !173 + %2805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 26, !dbg !173 + %2806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 27, !dbg !173 + %2807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 28, !dbg !173 + %2808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 29, !dbg !173 + %2809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 30, !dbg !173 + %2810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2767, 31, !dbg !173 + %2811 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2779, float %2780, float %2781, float %2782, float %2783, float %2784, float %2785, float %2786, float %2787, float %2788, float %2789, float %2790, float %2791, float %2792, float %2793, float %2794, float %2795, float %2796, float %2797, float %2798, float %2799, float %2800, float %2801, float %2802, float %2803, float %2804, float %2805, float %2806, float %2807, float %2808, float %2809, float %2810, i64 %2773, i64 %2778, i1 true) #3, !dbg !173 + %2812 = or disjoint i32 %2756, 64, !dbg !173 + %2813 = add i32 %2812, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2814 = lshr exact i32 %2813, 4, !dbg !173 + %2815 = and i32 %2814, 16383, !dbg !173 + %2816 = zext nneg i32 %2815 to i64, !dbg !173 + %2817 = or disjoint i64 %2816, 4611686293372403712, !dbg !173 + %2818 = add i32 %2762, 64, !dbg !173 + %2819 = lshr exact i32 %2818, 4, !dbg !173 + %2820 = and i32 %2819, 16383, !dbg !173 + %2821 = zext nneg i32 %2820 to i64, !dbg !173 + %2822 = or disjoint i64 %2821, 4611686293338849280, !dbg !173 + %2823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 0, !dbg !173 + %2824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 1, !dbg !173 + %2825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 2, !dbg !173 + %2826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 3, !dbg !173 + %2827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 4, !dbg !173 + %2828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 5, !dbg !173 + %2829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 6, !dbg !173 + %2830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 7, !dbg !173 + %2831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 8, !dbg !173 + %2832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 9, !dbg !173 + %2833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 10, !dbg !173 + %2834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 11, !dbg !173 + %2835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 12, !dbg !173 + %2836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 13, !dbg !173 + %2837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 14, !dbg !173 + %2838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 15, !dbg !173 + %2839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 16, !dbg !173 + %2840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 17, !dbg !173 + %2841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 18, !dbg !173 + %2842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 19, !dbg !173 + %2843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 20, !dbg !173 + %2844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 21, !dbg !173 + %2845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 22, !dbg !173 + %2846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 23, !dbg !173 + %2847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 24, !dbg !173 + %2848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 25, !dbg !173 + %2849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 26, !dbg !173 + %2850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 27, !dbg !173 + %2851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 28, !dbg !173 + %2852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 29, !dbg !173 + %2853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 30, !dbg !173 + %2854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2811, 31, !dbg !173 + %2855 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2823, float %2824, float %2825, float %2826, float %2827, float %2828, float %2829, float %2830, float %2831, float %2832, float %2833, float %2834, float %2835, float %2836, float %2837, float %2838, float %2839, float %2840, float %2841, float %2842, float %2843, float %2844, float %2845, float %2846, float %2847, float %2848, float %2849, float %2850, float %2851, float %2852, float %2853, float %2854, i64 %2817, i64 %2822, i1 true) #3, !dbg !173 + %2856 = or disjoint i32 %2756, 96, !dbg !173 + %2857 = add i32 %2856, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2858 = lshr exact i32 %2857, 4, !dbg !173 + %2859 = and i32 %2858, 16383, !dbg !173 + %2860 = zext nneg i32 %2859 to i64, !dbg !173 + %2861 = or disjoint i64 %2860, 4611686293372403712, !dbg !173 + %2862 = add i32 %2762, 96, !dbg !173 + %2863 = lshr exact i32 %2862, 4, !dbg !173 + %2864 = and i32 %2863, 16383, !dbg !173 + %2865 = zext nneg i32 %2864 to i64, !dbg !173 + %2866 = or disjoint i64 %2865, 4611686293338849280, !dbg !173 + %2867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 0, !dbg !173 + %2868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 1, !dbg !173 + %2869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 2, !dbg !173 + %2870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 3, !dbg !173 + %2871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 4, !dbg !173 + %2872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 5, !dbg !173 + %2873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 6, !dbg !173 + %2874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 7, !dbg !173 + %2875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 8, !dbg !173 + %2876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 9, !dbg !173 + %2877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 10, !dbg !173 + %2878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 11, !dbg !173 + %2879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 12, !dbg !173 + %2880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 13, !dbg !173 + %2881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 14, !dbg !173 + %2882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 15, !dbg !173 + %2883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 16, !dbg !173 + %2884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 17, !dbg !173 + %2885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 18, !dbg !173 + %2886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 19, !dbg !173 + %2887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 20, !dbg !173 + %2888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 21, !dbg !173 + %2889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 22, !dbg !173 + %2890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 23, !dbg !173 + %2891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 24, !dbg !173 + %2892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 25, !dbg !173 + %2893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 26, !dbg !173 + %2894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 27, !dbg !173 + %2895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 28, !dbg !173 + %2896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 29, !dbg !173 + %2897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 30, !dbg !173 + %2898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2855, 31, !dbg !173 + %2899 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2867, float %2868, float %2869, float %2870, float %2871, float %2872, float %2873, float %2874, float %2875, float %2876, float %2877, float %2878, float %2879, float %2880, float %2881, float %2882, float %2883, float %2884, float %2885, float %2886, float %2887, float %2888, float %2889, float %2890, float %2891, float %2892, float %2893, float %2894, float %2895, float %2896, float %2897, float %2898, i64 %2861, i64 %2866, i1 true) #3, !dbg !173 + %2900 = or disjoint i32 %2756, 16384, !dbg !173 + %2901 = add i32 %2900, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2902 = lshr exact i32 %2901, 4, !dbg !173 + %2903 = and i32 %2902, 16383, !dbg !173 + %2904 = zext nneg i32 %2903 to i64, !dbg !173 + %2905 = or disjoint i64 %2904, 4611686293372403712, !dbg !173 + %2906 = add i32 %2762, 8192, !dbg !173 + %2907 = lshr exact i32 %2906, 4, !dbg !173 + %2908 = and i32 %2907, 16383, !dbg !173 + %2909 = zext nneg i32 %2908 to i64, !dbg !173 + %2910 = or disjoint i64 %2909, 4611686293338849280, !dbg !173 + %2911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 0, !dbg !173 + %2912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 1, !dbg !173 + %2913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 2, !dbg !173 + %2914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 3, !dbg !173 + %2915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 4, !dbg !173 + %2916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 5, !dbg !173 + %2917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 6, !dbg !173 + %2918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 7, !dbg !173 + %2919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 8, !dbg !173 + %2920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 9, !dbg !173 + %2921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 10, !dbg !173 + %2922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 11, !dbg !173 + %2923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 12, !dbg !173 + %2924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 13, !dbg !173 + %2925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 14, !dbg !173 + %2926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 15, !dbg !173 + %2927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 16, !dbg !173 + %2928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 17, !dbg !173 + %2929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 18, !dbg !173 + %2930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 19, !dbg !173 + %2931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 20, !dbg !173 + %2932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 21, !dbg !173 + %2933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 22, !dbg !173 + %2934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 23, !dbg !173 + %2935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 24, !dbg !173 + %2936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 25, !dbg !173 + %2937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 26, !dbg !173 + %2938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 27, !dbg !173 + %2939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 28, !dbg !173 + %2940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 29, !dbg !173 + %2941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 30, !dbg !173 + %2942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2899, 31, !dbg !173 + %2943 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2911, float %2912, float %2913, float %2914, float %2915, float %2916, float %2917, float %2918, float %2919, float %2920, float %2921, float %2922, float %2923, float %2924, float %2925, float %2926, float %2927, float %2928, float %2929, float %2930, float %2931, float %2932, float %2933, float %2934, float %2935, float %2936, float %2937, float %2938, float %2939, float %2940, float %2941, float %2942, i64 %2905, i64 %2910, i1 true) #3, !dbg !173 + %2944 = or disjoint i32 %2756, 16416, !dbg !173 + %2945 = add i32 %2944, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2946 = lshr exact i32 %2945, 4, !dbg !173 + %2947 = and i32 %2946, 16383, !dbg !173 + %2948 = zext nneg i32 %2947 to i64, !dbg !173 + %2949 = or disjoint i64 %2948, 4611686293372403712, !dbg !173 + %2950 = add i32 %2762, 8224, !dbg !173 + %2951 = lshr exact i32 %2950, 4, !dbg !173 + %2952 = and i32 %2951, 16383, !dbg !173 + %2953 = zext nneg i32 %2952 to i64, !dbg !173 + %2954 = or disjoint i64 %2953, 4611686293338849280, !dbg !173 + %2955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 0, !dbg !173 + %2956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 1, !dbg !173 + %2957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 2, !dbg !173 + %2958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 3, !dbg !173 + %2959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 4, !dbg !173 + %2960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 5, !dbg !173 + %2961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 6, !dbg !173 + %2962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 7, !dbg !173 + %2963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 8, !dbg !173 + %2964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 9, !dbg !173 + %2965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 10, !dbg !173 + %2966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 11, !dbg !173 + %2967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 12, !dbg !173 + %2968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 13, !dbg !173 + %2969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 14, !dbg !173 + %2970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 15, !dbg !173 + %2971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 16, !dbg !173 + %2972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 17, !dbg !173 + %2973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 18, !dbg !173 + %2974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 19, !dbg !173 + %2975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 20, !dbg !173 + %2976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 21, !dbg !173 + %2977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 22, !dbg !173 + %2978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 23, !dbg !173 + %2979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 24, !dbg !173 + %2980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 25, !dbg !173 + %2981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 26, !dbg !173 + %2982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 27, !dbg !173 + %2983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 28, !dbg !173 + %2984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 29, !dbg !173 + %2985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 30, !dbg !173 + %2986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2943, 31, !dbg !173 + %2987 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2955, float %2956, float %2957, float %2958, float %2959, float %2960, float %2961, float %2962, float %2963, float %2964, float %2965, float %2966, float %2967, float %2968, float %2969, float %2970, float %2971, float %2972, float %2973, float %2974, float %2975, float %2976, float %2977, float %2978, float %2979, float %2980, float %2981, float %2982, float %2983, float %2984, float %2985, float %2986, i64 %2949, i64 %2954, i1 true) #3, !dbg !173 + %2988 = or disjoint i32 %2756, 16448, !dbg !173 + %2989 = add i32 %2988, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %2990 = lshr exact i32 %2989, 4, !dbg !173 + %2991 = and i32 %2990, 16383, !dbg !173 + %2992 = zext nneg i32 %2991 to i64, !dbg !173 + %2993 = or disjoint i64 %2992, 4611686293372403712, !dbg !173 + %2994 = add i32 %2762, 8256, !dbg !173 + %2995 = lshr exact i32 %2994, 4, !dbg !173 + %2996 = and i32 %2995, 16383, !dbg !173 + %2997 = zext nneg i32 %2996 to i64, !dbg !173 + %2998 = or disjoint i64 %2997, 4611686293338849280, !dbg !173 + %2999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 0, !dbg !173 + %3000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 1, !dbg !173 + %3001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 2, !dbg !173 + %3002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 3, !dbg !173 + %3003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 4, !dbg !173 + %3004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 5, !dbg !173 + %3005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 6, !dbg !173 + %3006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 7, !dbg !173 + %3007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 8, !dbg !173 + %3008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 9, !dbg !173 + %3009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 10, !dbg !173 + %3010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 11, !dbg !173 + %3011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 12, !dbg !173 + %3012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 13, !dbg !173 + %3013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 14, !dbg !173 + %3014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 15, !dbg !173 + %3015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 16, !dbg !173 + %3016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 17, !dbg !173 + %3017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 18, !dbg !173 + %3018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 19, !dbg !173 + %3019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 20, !dbg !173 + %3020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 21, !dbg !173 + %3021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 22, !dbg !173 + %3022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 23, !dbg !173 + %3023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 24, !dbg !173 + %3024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 25, !dbg !173 + %3025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 26, !dbg !173 + %3026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 27, !dbg !173 + %3027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 28, !dbg !173 + %3028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 29, !dbg !173 + %3029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 30, !dbg !173 + %3030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2987, 31, !dbg !173 + %3031 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2999, float %3000, float %3001, float %3002, float %3003, float %3004, float %3005, float %3006, float %3007, float %3008, float %3009, float %3010, float %3011, float %3012, float %3013, float %3014, float %3015, float %3016, float %3017, float %3018, float %3019, float %3020, float %3021, float %3022, float %3023, float %3024, float %3025, float %3026, float %3027, float %3028, float %3029, float %3030, i64 %2993, i64 %2998, i1 true) #3, !dbg !173 + %3032 = or disjoint i32 %2756, 16480, !dbg !173 + %3033 = add i32 %3032, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !173 + %3034 = lshr exact i32 %3033, 4, !dbg !173 + %3035 = and i32 %3034, 16383, !dbg !173 + %3036 = zext nneg i32 %3035 to i64, !dbg !173 + %3037 = or disjoint i64 %3036, 4611686293372403712, !dbg !173 + %3038 = add i32 %2762, 8288, !dbg !173 + %3039 = lshr exact i32 %3038, 4, !dbg !173 + %3040 = and i32 %3039, 16383, !dbg !173 + %3041 = zext nneg i32 %3040 to i64, !dbg !173 + %3042 = or disjoint i64 %3041, 4611686293338849280, !dbg !173 + %3043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 0, !dbg !173 + %3044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 1, !dbg !173 + %3045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 2, !dbg !173 + %3046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 3, !dbg !173 + %3047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 4, !dbg !173 + %3048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 5, !dbg !173 + %3049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 6, !dbg !173 + %3050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 7, !dbg !173 + %3051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 8, !dbg !173 + %3052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 9, !dbg !173 + %3053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 10, !dbg !173 + %3054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 11, !dbg !173 + %3055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 12, !dbg !173 + %3056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 13, !dbg !173 + %3057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 14, !dbg !173 + %3058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 15, !dbg !173 + %3059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 16, !dbg !173 + %3060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 17, !dbg !173 + %3061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 18, !dbg !173 + %3062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 19, !dbg !173 + %3063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 20, !dbg !173 + %3064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 21, !dbg !173 + %3065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 22, !dbg !173 + %3066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 23, !dbg !173 + %3067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 24, !dbg !173 + %3068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 25, !dbg !173 + %3069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 26, !dbg !173 + %3070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 27, !dbg !173 + %3071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 28, !dbg !173 + %3072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 29, !dbg !173 + %3073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 30, !dbg !173 + %3074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3031, 31, !dbg !173 + %3075 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3043, float %3044, float %3045, float %3046, float %3047, float %3048, float %3049, float %3050, float %3051, float %3052, float %3053, float %3054, float %3055, float %3056, float %3057, float %3058, float %3059, float %3060, float %3061, float %3062, float %3063, float %3064, float %3065, float %3066, float %3067, float %3068, float %3069, float %3070, float %3071, float %3072, float %3073, float %3074, i64 %3037, i64 %3042, i1 true) #3, !dbg !173 + %3076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 0, !dbg !173 + %3077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 1, !dbg !173 + %3078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 2, !dbg !173 + %3079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 3, !dbg !173 + %3080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 4, !dbg !173 + %3081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 5, !dbg !173 + %3082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 6, !dbg !173 + %3083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 7, !dbg !173 + %3084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 8, !dbg !173 + %3085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 9, !dbg !173 + %3086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 10, !dbg !173 + %3087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 11, !dbg !173 + %3088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 12, !dbg !173 + %3089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 13, !dbg !173 + %3090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 14, !dbg !173 + %3091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 15, !dbg !173 + %3092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 16, !dbg !173 + %3093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 17, !dbg !173 + %3094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 18, !dbg !173 + %3095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 19, !dbg !173 + %3096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 20, !dbg !173 + %3097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 21, !dbg !173 + %3098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 22, !dbg !173 + %3099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 23, !dbg !173 + %3100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 24, !dbg !173 + %3101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 25, !dbg !173 + %3102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 26, !dbg !173 + %3103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 27, !dbg !173 + %3104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 28, !dbg !173 + %3105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 29, !dbg !173 + %3106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 30, !dbg !173 + %3107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3075, 31, !dbg !173 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !173 + %3108 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %3076, float %3077, float %3078, float %3079, float %3080, float %3081, float %3082, float %3083, float %3084, float %3085, float %3086, float %3087, float %3088, float %3089, float %3090, float %3091, float %3092, float %3093, float %3094, float %3095, float %3096, float %3097, float %3098, float %3099, float %3100, float %3101, float %3102, float %3103, float %3104, float %3105, float %3106, float %3107, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %2753, i32 0, i32 0) #3, !dbg !173 + %3109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 0, !dbg !173 + %3110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 1, !dbg !173 + %3111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 2, !dbg !173 + %3112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 3, !dbg !173 + %3113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 4, !dbg !173 + %3114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 5, !dbg !173 + %3115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 6, !dbg !173 + %3116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 7, !dbg !173 + %3117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 8, !dbg !173 + %3118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 9, !dbg !173 + %3119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 10, !dbg !173 + %3120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 11, !dbg !173 + %3121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 12, !dbg !173 + %3122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 13, !dbg !173 + %3123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 14, !dbg !173 + %3124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 15, !dbg !173 + %3125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 16, !dbg !173 + %3126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 17, !dbg !173 + %3127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 18, !dbg !173 + %3128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 19, !dbg !173 + %3129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 20, !dbg !173 + %3130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 21, !dbg !173 + %3131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 22, !dbg !173 + %3132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 23, !dbg !173 + %3133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 24, !dbg !173 + %3134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 25, !dbg !173 + %3135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 26, !dbg !173 + %3136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 27, !dbg !173 + %3137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 28, !dbg !173 + %3138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 29, !dbg !173 + %3139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 30, !dbg !173 + %3140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3108, 31, !dbg !173 + %3141 = fmul float %3109, 0x3FB6A09E60000000, !dbg !174 + %3142 = fmul float %3110, 0x3FB6A09E60000000, !dbg !174 + %3143 = fmul float %3111, 0x3FB6A09E60000000, !dbg !174 + %3144 = fmul float %3112, 0x3FB6A09E60000000, !dbg !174 + %3145 = fmul float %3113, 0x3FB6A09E60000000, !dbg !174 + %3146 = fmul float %3114, 0x3FB6A09E60000000, !dbg !174 + %3147 = fmul float %3115, 0x3FB6A09E60000000, !dbg !174 + %3148 = fmul float %3116, 0x3FB6A09E60000000, !dbg !174 + %3149 = fmul float %3117, 0x3FB6A09E60000000, !dbg !174 + %3150 = fmul float %3118, 0x3FB6A09E60000000, !dbg !174 + %3151 = fmul float %3119, 0x3FB6A09E60000000, !dbg !174 + %3152 = fmul float %3120, 0x3FB6A09E60000000, !dbg !174 + %3153 = fmul float %3121, 0x3FB6A09E60000000, !dbg !174 + %3154 = fmul float %3122, 0x3FB6A09E60000000, !dbg !174 + %3155 = fmul float %3123, 0x3FB6A09E60000000, !dbg !174 + %3156 = fmul float %3124, 0x3FB6A09E60000000, !dbg !174 + %3157 = fmul float %3125, 0x3FB6A09E60000000, !dbg !174 + %3158 = fmul float %3126, 0x3FB6A09E60000000, !dbg !174 + %3159 = fmul float %3127, 0x3FB6A09E60000000, !dbg !174 + %3160 = fmul float %3128, 0x3FB6A09E60000000, !dbg !174 + %3161 = fmul float %3129, 0x3FB6A09E60000000, !dbg !174 + %3162 = fmul float %3130, 0x3FB6A09E60000000, !dbg !174 + %3163 = fmul float %3131, 0x3FB6A09E60000000, !dbg !174 + %3164 = fmul float %3132, 0x3FB6A09E60000000, !dbg !174 + %3165 = fmul float %3133, 0x3FB6A09E60000000, !dbg !174 + %3166 = fmul float %3134, 0x3FB6A09E60000000, !dbg !174 + %3167 = fmul float %3135, 0x3FB6A09E60000000, !dbg !174 + %3168 = fmul float %3136, 0x3FB6A09E60000000, !dbg !174 + %3169 = fmul float %3137, 0x3FB6A09E60000000, !dbg !174 + %3170 = fmul float %3138, 0x3FB6A09E60000000, !dbg !174 + %3171 = fmul float %3139, 0x3FB6A09E60000000, !dbg !174 + %3172 = fmul float %3140, 0x3FB6A09E60000000, !dbg !174 + %3173 = fmul float %3141, 0x3FF7154760000000, !dbg !175 + %3174 = select i1 %2721, float %3173, float 0xFFF0000000000000, !dbg !176 + %3175 = fmul float %3142, 0x3FF7154760000000, !dbg !175 + %3176 = select i1 %2723, float %3175, float 0xFFF0000000000000, !dbg !176 + %3177 = fmul float %3143, 0x3FF7154760000000, !dbg !175 + %3178 = select i1 %2721, float %3177, float 0xFFF0000000000000, !dbg !176 + %3179 = fmul float %3144, 0x3FF7154760000000, !dbg !175 + %3180 = select i1 %2723, float %3179, float 0xFFF0000000000000, !dbg !176 + %3181 = fmul float %3145, 0x3FF7154760000000, !dbg !175 + %3182 = select i1 %2725, float %3181, float 0xFFF0000000000000, !dbg !176 + %3183 = fmul float %3146, 0x3FF7154760000000, !dbg !175 + %3184 = select i1 %2727, float %3183, float 0xFFF0000000000000, !dbg !176 + %3185 = fmul float %3147, 0x3FF7154760000000, !dbg !175 + %3186 = select i1 %2725, float %3185, float 0xFFF0000000000000, !dbg !176 + %3187 = fmul float %3148, 0x3FF7154760000000, !dbg !175 + %3188 = select i1 %2727, float %3187, float 0xFFF0000000000000, !dbg !176 + %3189 = fmul float %3149, 0x3FF7154760000000, !dbg !175 + %3190 = select i1 %2729, float %3189, float 0xFFF0000000000000, !dbg !176 + %3191 = fmul float %3150, 0x3FF7154760000000, !dbg !175 + %3192 = select i1 %2731, float %3191, float 0xFFF0000000000000, !dbg !176 + %3193 = fmul float %3151, 0x3FF7154760000000, !dbg !175 + %3194 = select i1 %2729, float %3193, float 0xFFF0000000000000, !dbg !176 + %3195 = fmul float %3152, 0x3FF7154760000000, !dbg !175 + %3196 = select i1 %2731, float %3195, float 0xFFF0000000000000, !dbg !176 + %3197 = fmul float %3153, 0x3FF7154760000000, !dbg !175 + %3198 = select i1 %2733, float %3197, float 0xFFF0000000000000, !dbg !176 + %3199 = fmul float %3154, 0x3FF7154760000000, !dbg !175 + %3200 = select i1 %2735, float %3199, float 0xFFF0000000000000, !dbg !176 + %3201 = fmul float %3155, 0x3FF7154760000000, !dbg !175 + %3202 = select i1 %2733, float %3201, float 0xFFF0000000000000, !dbg !176 + %3203 = fmul float %3156, 0x3FF7154760000000, !dbg !175 + %3204 = select i1 %2735, float %3203, float 0xFFF0000000000000, !dbg !176 + %3205 = fmul float %3157, 0x3FF7154760000000, !dbg !175 + %3206 = select i1 %2737, float %3205, float 0xFFF0000000000000, !dbg !176 + %3207 = fmul float %3158, 0x3FF7154760000000, !dbg !175 + %3208 = select i1 %2739, float %3207, float 0xFFF0000000000000, !dbg !176 + %3209 = fmul float %3159, 0x3FF7154760000000, !dbg !175 + %3210 = select i1 %2737, float %3209, float 0xFFF0000000000000, !dbg !176 + %3211 = fmul float %3160, 0x3FF7154760000000, !dbg !175 + %3212 = select i1 %2739, float %3211, float 0xFFF0000000000000, !dbg !176 + %3213 = fmul float %3161, 0x3FF7154760000000, !dbg !175 + %3214 = select i1 %2741, float %3213, float 0xFFF0000000000000, !dbg !176 + %3215 = fmul float %3162, 0x3FF7154760000000, !dbg !175 + %3216 = select i1 %2743, float %3215, float 0xFFF0000000000000, !dbg !176 + %3217 = fmul float %3163, 0x3FF7154760000000, !dbg !175 + %3218 = select i1 %2741, float %3217, float 0xFFF0000000000000, !dbg !176 + %3219 = fmul float %3164, 0x3FF7154760000000, !dbg !175 + %3220 = select i1 %2743, float %3219, float 0xFFF0000000000000, !dbg !176 + %3221 = fmul float %3165, 0x3FF7154760000000, !dbg !175 + %3222 = select i1 %2745, float %3221, float 0xFFF0000000000000, !dbg !176 + %3223 = fmul float %3166, 0x3FF7154760000000, !dbg !175 + %3224 = select i1 %2747, float %3223, float 0xFFF0000000000000, !dbg !176 + %3225 = fmul float %3167, 0x3FF7154760000000, !dbg !175 + %3226 = select i1 %2745, float %3225, float 0xFFF0000000000000, !dbg !176 + %3227 = fmul float %3168, 0x3FF7154760000000, !dbg !175 + %3228 = select i1 %2747, float %3227, float 0xFFF0000000000000, !dbg !176 + %3229 = fmul float %3169, 0x3FF7154760000000, !dbg !175 + %3230 = select i1 %2749, float %3229, float 0xFFF0000000000000, !dbg !176 + %3231 = fmul float %3170, 0x3FF7154760000000, !dbg !175 + %3232 = select i1 %2751, float %3231, float 0xFFF0000000000000, !dbg !176 + %3233 = fmul float %3171, 0x3FF7154760000000, !dbg !175 + %3234 = select i1 %2749, float %3233, float 0xFFF0000000000000, !dbg !176 + %3235 = fmul float %3172, 0x3FF7154760000000, !dbg !175 + %3236 = select i1 %2751, float %3235, float 0xFFF0000000000000, !dbg !176 + %3237 = fsub float %3174, %373, !dbg !177 + %3238 = fsub float %3176, %373, !dbg !177 + %3239 = fsub float %3178, %374, !dbg !177 + %3240 = fsub float %3180, %374, !dbg !177 + %3241 = fsub float %3182, %373, !dbg !177 + %3242 = fsub float %3184, %373, !dbg !177 + %3243 = fsub float %3186, %374, !dbg !177 + %3244 = fsub float %3188, %374, !dbg !177 + %3245 = fsub float %3190, %373, !dbg !177 + %3246 = fsub float %3192, %373, !dbg !177 + %3247 = fsub float %3194, %374, !dbg !177 + %3248 = fsub float %3196, %374, !dbg !177 + %3249 = fsub float %3198, %373, !dbg !177 + %3250 = fsub float %3200, %373, !dbg !177 + %3251 = fsub float %3202, %374, !dbg !177 + %3252 = fsub float %3204, %374, !dbg !177 + %3253 = fsub float %3206, %373, !dbg !177 + %3254 = fsub float %3208, %373, !dbg !177 + %3255 = fsub float %3210, %374, !dbg !177 + %3256 = fsub float %3212, %374, !dbg !177 + %3257 = fsub float %3214, %373, !dbg !177 + %3258 = fsub float %3216, %373, !dbg !177 + %3259 = fsub float %3218, %374, !dbg !177 + %3260 = fsub float %3220, %374, !dbg !177 + %3261 = fsub float %3222, %373, !dbg !177 + %3262 = fsub float %3224, %373, !dbg !177 + %3263 = fsub float %3226, %374, !dbg !177 + %3264 = fsub float %3228, %374, !dbg !177 + %3265 = fsub float %3230, %373, !dbg !177 + %3266 = fsub float %3232, %373, !dbg !177 + %3267 = fsub float %3234, %374, !dbg !177 + %3268 = fsub float %3236, %374, !dbg !177 + %3269 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1418 = icmp eq i32 %3269, 0, !dbg !178 + br i1 %.not.i1418, label %3272, label %3270, !dbg !178 + +3270: ; preds = %2709 + %3271 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3237) #3, !dbg !178 + br label %__nv_exp2f.exit1420, !dbg !178 + +3272: ; preds = %2709 + %3273 = tail call float @llvm.nvvm.ex2.approx.f(float %3237) #3, !dbg !178 + br label %__nv_exp2f.exit1420, !dbg !178 + +__nv_exp2f.exit1420: ; preds = %3270, %3272 + %.0.i1419 = phi float [ %3271, %3270 ], [ %3273, %3272 ], !dbg !178 + %3274 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1421 = icmp eq i32 %3274, 0, !dbg !178 + br i1 %.not.i1421, label %3277, label %3275, !dbg !178 + +3275: ; preds = %__nv_exp2f.exit1420 + %3276 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3238) #3, !dbg !178 + br label %__nv_exp2f.exit1423, !dbg !178 + +3277: ; preds = %__nv_exp2f.exit1420 + %3278 = tail call float @llvm.nvvm.ex2.approx.f(float %3238) #3, !dbg !178 + br label %__nv_exp2f.exit1423, !dbg !178 + +__nv_exp2f.exit1423: ; preds = %3275, %3277 + %.0.i1422 = phi float [ %3276, %3275 ], [ %3278, %3277 ], !dbg !178 + %3279 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1424 = icmp eq i32 %3279, 0, !dbg !178 + br i1 %.not.i1424, label %3282, label %3280, !dbg !178 + +3280: ; preds = %__nv_exp2f.exit1423 + %3281 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3239) #3, !dbg !178 + br label %__nv_exp2f.exit1426, !dbg !178 + +3282: ; preds = %__nv_exp2f.exit1423 + %3283 = tail call float @llvm.nvvm.ex2.approx.f(float %3239) #3, !dbg !178 + br label %__nv_exp2f.exit1426, !dbg !178 + +__nv_exp2f.exit1426: ; preds = %3280, %3282 + %.0.i1425 = phi float [ %3281, %3280 ], [ %3283, %3282 ], !dbg !178 + %3284 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1427 = icmp eq i32 %3284, 0, !dbg !178 + br i1 %.not.i1427, label %3287, label %3285, !dbg !178 + +3285: ; preds = %__nv_exp2f.exit1426 + %3286 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3240) #3, !dbg !178 + br label %__nv_exp2f.exit1429, !dbg !178 + +3287: ; preds = %__nv_exp2f.exit1426 + %3288 = tail call float @llvm.nvvm.ex2.approx.f(float %3240) #3, !dbg !178 + br label %__nv_exp2f.exit1429, !dbg !178 + +__nv_exp2f.exit1429: ; preds = %3285, %3287 + %.0.i1428 = phi float [ %3286, %3285 ], [ %3288, %3287 ], !dbg !178 + %3289 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1430 = icmp eq i32 %3289, 0, !dbg !178 + br i1 %.not.i1430, label %3292, label %3290, !dbg !178 + +3290: ; preds = %__nv_exp2f.exit1429 + %3291 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3241) #3, !dbg !178 + br label %__nv_exp2f.exit1432, !dbg !178 + +3292: ; preds = %__nv_exp2f.exit1429 + %3293 = tail call float @llvm.nvvm.ex2.approx.f(float %3241) #3, !dbg !178 + br label %__nv_exp2f.exit1432, !dbg !178 + +__nv_exp2f.exit1432: ; preds = %3290, %3292 + %.0.i1431 = phi float [ %3291, %3290 ], [ %3293, %3292 ], !dbg !178 + %3294 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1433 = icmp eq i32 %3294, 0, !dbg !178 + br i1 %.not.i1433, label %3297, label %3295, !dbg !178 + +3295: ; preds = %__nv_exp2f.exit1432 + %3296 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3242) #3, !dbg !178 + br label %__nv_exp2f.exit1435, !dbg !178 + +3297: ; preds = %__nv_exp2f.exit1432 + %3298 = tail call float @llvm.nvvm.ex2.approx.f(float %3242) #3, !dbg !178 + br label %__nv_exp2f.exit1435, !dbg !178 + +__nv_exp2f.exit1435: ; preds = %3295, %3297 + %.0.i1434 = phi float [ %3296, %3295 ], [ %3298, %3297 ], !dbg !178 + %3299 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1436 = icmp eq i32 %3299, 0, !dbg !178 + br i1 %.not.i1436, label %3302, label %3300, !dbg !178 + +3300: ; preds = %__nv_exp2f.exit1435 + %3301 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3243) #3, !dbg !178 + br label %__nv_exp2f.exit1438, !dbg !178 + +3302: ; preds = %__nv_exp2f.exit1435 + %3303 = tail call float @llvm.nvvm.ex2.approx.f(float %3243) #3, !dbg !178 + br label %__nv_exp2f.exit1438, !dbg !178 + +__nv_exp2f.exit1438: ; preds = %3300, %3302 + %.0.i1437 = phi float [ %3301, %3300 ], [ %3303, %3302 ], !dbg !178 + %3304 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1439 = icmp eq i32 %3304, 0, !dbg !178 + br i1 %.not.i1439, label %3307, label %3305, !dbg !178 + +3305: ; preds = %__nv_exp2f.exit1438 + %3306 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3244) #3, !dbg !178 + br label %__nv_exp2f.exit1441, !dbg !178 + +3307: ; preds = %__nv_exp2f.exit1438 + %3308 = tail call float @llvm.nvvm.ex2.approx.f(float %3244) #3, !dbg !178 + br label %__nv_exp2f.exit1441, !dbg !178 + +__nv_exp2f.exit1441: ; preds = %3305, %3307 + %.0.i1440 = phi float [ %3306, %3305 ], [ %3308, %3307 ], !dbg !178 + %3309 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1442 = icmp eq i32 %3309, 0, !dbg !178 + br i1 %.not.i1442, label %3312, label %3310, !dbg !178 + +3310: ; preds = %__nv_exp2f.exit1441 + %3311 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3245) #3, !dbg !178 + br label %__nv_exp2f.exit1444, !dbg !178 + +3312: ; preds = %__nv_exp2f.exit1441 + %3313 = tail call float @llvm.nvvm.ex2.approx.f(float %3245) #3, !dbg !178 + br label %__nv_exp2f.exit1444, !dbg !178 + +__nv_exp2f.exit1444: ; preds = %3310, %3312 + %.0.i1443 = phi float [ %3311, %3310 ], [ %3313, %3312 ], !dbg !178 + %3314 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1445 = icmp eq i32 %3314, 0, !dbg !178 + br i1 %.not.i1445, label %3317, label %3315, !dbg !178 + +3315: ; preds = %__nv_exp2f.exit1444 + %3316 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3246) #3, !dbg !178 + br label %__nv_exp2f.exit1447, !dbg !178 + +3317: ; preds = %__nv_exp2f.exit1444 + %3318 = tail call float @llvm.nvvm.ex2.approx.f(float %3246) #3, !dbg !178 + br label %__nv_exp2f.exit1447, !dbg !178 + +__nv_exp2f.exit1447: ; preds = %3315, %3317 + %.0.i1446 = phi float [ %3316, %3315 ], [ %3318, %3317 ], !dbg !178 + %3319 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1448 = icmp eq i32 %3319, 0, !dbg !178 + br i1 %.not.i1448, label %3322, label %3320, !dbg !178 + +3320: ; preds = %__nv_exp2f.exit1447 + %3321 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3247) #3, !dbg !178 + br label %__nv_exp2f.exit1450, !dbg !178 + +3322: ; preds = %__nv_exp2f.exit1447 + %3323 = tail call float @llvm.nvvm.ex2.approx.f(float %3247) #3, !dbg !178 + br label %__nv_exp2f.exit1450, !dbg !178 + +__nv_exp2f.exit1450: ; preds = %3320, %3322 + %.0.i1449 = phi float [ %3321, %3320 ], [ %3323, %3322 ], !dbg !178 + %3324 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1451 = icmp eq i32 %3324, 0, !dbg !178 + br i1 %.not.i1451, label %3327, label %3325, !dbg !178 + +3325: ; preds = %__nv_exp2f.exit1450 + %3326 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3248) #3, !dbg !178 + br label %__nv_exp2f.exit1453, !dbg !178 + +3327: ; preds = %__nv_exp2f.exit1450 + %3328 = tail call float @llvm.nvvm.ex2.approx.f(float %3248) #3, !dbg !178 + br label %__nv_exp2f.exit1453, !dbg !178 + +__nv_exp2f.exit1453: ; preds = %3325, %3327 + %.0.i1452 = phi float [ %3326, %3325 ], [ %3328, %3327 ], !dbg !178 + %3329 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1454 = icmp eq i32 %3329, 0, !dbg !178 + br i1 %.not.i1454, label %3332, label %3330, !dbg !178 + +3330: ; preds = %__nv_exp2f.exit1453 + %3331 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3249) #3, !dbg !178 + br label %__nv_exp2f.exit1456, !dbg !178 + +3332: ; preds = %__nv_exp2f.exit1453 + %3333 = tail call float @llvm.nvvm.ex2.approx.f(float %3249) #3, !dbg !178 + br label %__nv_exp2f.exit1456, !dbg !178 + +__nv_exp2f.exit1456: ; preds = %3330, %3332 + %.0.i1455 = phi float [ %3331, %3330 ], [ %3333, %3332 ], !dbg !178 + %3334 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1457 = icmp eq i32 %3334, 0, !dbg !178 + br i1 %.not.i1457, label %3337, label %3335, !dbg !178 + +3335: ; preds = %__nv_exp2f.exit1456 + %3336 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3250) #3, !dbg !178 + br label %__nv_exp2f.exit1459, !dbg !178 + +3337: ; preds = %__nv_exp2f.exit1456 + %3338 = tail call float @llvm.nvvm.ex2.approx.f(float %3250) #3, !dbg !178 + br label %__nv_exp2f.exit1459, !dbg !178 + +__nv_exp2f.exit1459: ; preds = %3335, %3337 + %.0.i1458 = phi float [ %3336, %3335 ], [ %3338, %3337 ], !dbg !178 + %3339 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1460 = icmp eq i32 %3339, 0, !dbg !178 + br i1 %.not.i1460, label %3342, label %3340, !dbg !178 + +3340: ; preds = %__nv_exp2f.exit1459 + %3341 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3251) #3, !dbg !178 + br label %__nv_exp2f.exit1462, !dbg !178 + +3342: ; preds = %__nv_exp2f.exit1459 + %3343 = tail call float @llvm.nvvm.ex2.approx.f(float %3251) #3, !dbg !178 + br label %__nv_exp2f.exit1462, !dbg !178 + +__nv_exp2f.exit1462: ; preds = %3340, %3342 + %.0.i1461 = phi float [ %3341, %3340 ], [ %3343, %3342 ], !dbg !178 + %3344 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1463 = icmp eq i32 %3344, 0, !dbg !178 + br i1 %.not.i1463, label %3347, label %3345, !dbg !178 + +3345: ; preds = %__nv_exp2f.exit1462 + %3346 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3252) #3, !dbg !178 + br label %__nv_exp2f.exit1465, !dbg !178 + +3347: ; preds = %__nv_exp2f.exit1462 + %3348 = tail call float @llvm.nvvm.ex2.approx.f(float %3252) #3, !dbg !178 + br label %__nv_exp2f.exit1465, !dbg !178 + +__nv_exp2f.exit1465: ; preds = %3345, %3347 + %.0.i1464 = phi float [ %3346, %3345 ], [ %3348, %3347 ], !dbg !178 + %3349 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1466 = icmp eq i32 %3349, 0, !dbg !178 + br i1 %.not.i1466, label %3352, label %3350, !dbg !178 + +3350: ; preds = %__nv_exp2f.exit1465 + %3351 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3253) #3, !dbg !178 + br label %__nv_exp2f.exit1468, !dbg !178 + +3352: ; preds = %__nv_exp2f.exit1465 + %3353 = tail call float @llvm.nvvm.ex2.approx.f(float %3253) #3, !dbg !178 + br label %__nv_exp2f.exit1468, !dbg !178 + +__nv_exp2f.exit1468: ; preds = %3350, %3352 + %.0.i1467 = phi float [ %3351, %3350 ], [ %3353, %3352 ], !dbg !178 + %3354 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1469 = icmp eq i32 %3354, 0, !dbg !178 + br i1 %.not.i1469, label %3357, label %3355, !dbg !178 + +3355: ; preds = %__nv_exp2f.exit1468 + %3356 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3254) #3, !dbg !178 + br label %__nv_exp2f.exit1471, !dbg !178 + +3357: ; preds = %__nv_exp2f.exit1468 + %3358 = tail call float @llvm.nvvm.ex2.approx.f(float %3254) #3, !dbg !178 + br label %__nv_exp2f.exit1471, !dbg !178 + +__nv_exp2f.exit1471: ; preds = %3355, %3357 + %.0.i1470 = phi float [ %3356, %3355 ], [ %3358, %3357 ], !dbg !178 + %3359 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1472 = icmp eq i32 %3359, 0, !dbg !178 + br i1 %.not.i1472, label %3362, label %3360, !dbg !178 + +3360: ; preds = %__nv_exp2f.exit1471 + %3361 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3255) #3, !dbg !178 + br label %__nv_exp2f.exit1474, !dbg !178 + +3362: ; preds = %__nv_exp2f.exit1471 + %3363 = tail call float @llvm.nvvm.ex2.approx.f(float %3255) #3, !dbg !178 + br label %__nv_exp2f.exit1474, !dbg !178 + +__nv_exp2f.exit1474: ; preds = %3360, %3362 + %.0.i1473 = phi float [ %3361, %3360 ], [ %3363, %3362 ], !dbg !178 + %3364 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1475 = icmp eq i32 %3364, 0, !dbg !178 + br i1 %.not.i1475, label %3367, label %3365, !dbg !178 + +3365: ; preds = %__nv_exp2f.exit1474 + %3366 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3256) #3, !dbg !178 + br label %__nv_exp2f.exit1477, !dbg !178 + +3367: ; preds = %__nv_exp2f.exit1474 + %3368 = tail call float @llvm.nvvm.ex2.approx.f(float %3256) #3, !dbg !178 + br label %__nv_exp2f.exit1477, !dbg !178 + +__nv_exp2f.exit1477: ; preds = %3365, %3367 + %.0.i1476 = phi float [ %3366, %3365 ], [ %3368, %3367 ], !dbg !178 + %3369 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1478 = icmp eq i32 %3369, 0, !dbg !178 + br i1 %.not.i1478, label %3372, label %3370, !dbg !178 + +3370: ; preds = %__nv_exp2f.exit1477 + %3371 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3257) #3, !dbg !178 + br label %__nv_exp2f.exit1480, !dbg !178 + +3372: ; preds = %__nv_exp2f.exit1477 + %3373 = tail call float @llvm.nvvm.ex2.approx.f(float %3257) #3, !dbg !178 + br label %__nv_exp2f.exit1480, !dbg !178 + +__nv_exp2f.exit1480: ; preds = %3370, %3372 + %.0.i1479 = phi float [ %3371, %3370 ], [ %3373, %3372 ], !dbg !178 + %3374 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1481 = icmp eq i32 %3374, 0, !dbg !178 + br i1 %.not.i1481, label %3377, label %3375, !dbg !178 + +3375: ; preds = %__nv_exp2f.exit1480 + %3376 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3258) #3, !dbg !178 + br label %__nv_exp2f.exit1483, !dbg !178 + +3377: ; preds = %__nv_exp2f.exit1480 + %3378 = tail call float @llvm.nvvm.ex2.approx.f(float %3258) #3, !dbg !178 + br label %__nv_exp2f.exit1483, !dbg !178 + +__nv_exp2f.exit1483: ; preds = %3375, %3377 + %.0.i1482 = phi float [ %3376, %3375 ], [ %3378, %3377 ], !dbg !178 + %3379 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1484 = icmp eq i32 %3379, 0, !dbg !178 + br i1 %.not.i1484, label %3382, label %3380, !dbg !178 + +3380: ; preds = %__nv_exp2f.exit1483 + %3381 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3259) #3, !dbg !178 + br label %__nv_exp2f.exit1486, !dbg !178 + +3382: ; preds = %__nv_exp2f.exit1483 + %3383 = tail call float @llvm.nvvm.ex2.approx.f(float %3259) #3, !dbg !178 + br label %__nv_exp2f.exit1486, !dbg !178 + +__nv_exp2f.exit1486: ; preds = %3380, %3382 + %.0.i1485 = phi float [ %3381, %3380 ], [ %3383, %3382 ], !dbg !178 + %3384 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1487 = icmp eq i32 %3384, 0, !dbg !178 + br i1 %.not.i1487, label %3387, label %3385, !dbg !178 + +3385: ; preds = %__nv_exp2f.exit1486 + %3386 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3260) #3, !dbg !178 + br label %__nv_exp2f.exit1489, !dbg !178 + +3387: ; preds = %__nv_exp2f.exit1486 + %3388 = tail call float @llvm.nvvm.ex2.approx.f(float %3260) #3, !dbg !178 + br label %__nv_exp2f.exit1489, !dbg !178 + +__nv_exp2f.exit1489: ; preds = %3385, %3387 + %.0.i1488 = phi float [ %3386, %3385 ], [ %3388, %3387 ], !dbg !178 + %3389 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1490 = icmp eq i32 %3389, 0, !dbg !178 + br i1 %.not.i1490, label %3392, label %3390, !dbg !178 + +3390: ; preds = %__nv_exp2f.exit1489 + %3391 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3261) #3, !dbg !178 + br label %__nv_exp2f.exit1492, !dbg !178 + +3392: ; preds = %__nv_exp2f.exit1489 + %3393 = tail call float @llvm.nvvm.ex2.approx.f(float %3261) #3, !dbg !178 + br label %__nv_exp2f.exit1492, !dbg !178 + +__nv_exp2f.exit1492: ; preds = %3390, %3392 + %.0.i1491 = phi float [ %3391, %3390 ], [ %3393, %3392 ], !dbg !178 + %3394 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1493 = icmp eq i32 %3394, 0, !dbg !178 + br i1 %.not.i1493, label %3397, label %3395, !dbg !178 + +3395: ; preds = %__nv_exp2f.exit1492 + %3396 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3262) #3, !dbg !178 + br label %__nv_exp2f.exit1495, !dbg !178 + +3397: ; preds = %__nv_exp2f.exit1492 + %3398 = tail call float @llvm.nvvm.ex2.approx.f(float %3262) #3, !dbg !178 + br label %__nv_exp2f.exit1495, !dbg !178 + +__nv_exp2f.exit1495: ; preds = %3395, %3397 + %.0.i1494 = phi float [ %3396, %3395 ], [ %3398, %3397 ], !dbg !178 + %3399 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1496 = icmp eq i32 %3399, 0, !dbg !178 + br i1 %.not.i1496, label %3402, label %3400, !dbg !178 + +3400: ; preds = %__nv_exp2f.exit1495 + %3401 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3263) #3, !dbg !178 + br label %__nv_exp2f.exit1498, !dbg !178 + +3402: ; preds = %__nv_exp2f.exit1495 + %3403 = tail call float @llvm.nvvm.ex2.approx.f(float %3263) #3, !dbg !178 + br label %__nv_exp2f.exit1498, !dbg !178 + +__nv_exp2f.exit1498: ; preds = %3400, %3402 + %.0.i1497 = phi float [ %3401, %3400 ], [ %3403, %3402 ], !dbg !178 + %3404 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1499 = icmp eq i32 %3404, 0, !dbg !178 + br i1 %.not.i1499, label %3407, label %3405, !dbg !178 + +3405: ; preds = %__nv_exp2f.exit1498 + %3406 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3264) #3, !dbg !178 + br label %__nv_exp2f.exit1501, !dbg !178 + +3407: ; preds = %__nv_exp2f.exit1498 + %3408 = tail call float @llvm.nvvm.ex2.approx.f(float %3264) #3, !dbg !178 + br label %__nv_exp2f.exit1501, !dbg !178 + +__nv_exp2f.exit1501: ; preds = %3405, %3407 + %.0.i1500 = phi float [ %3406, %3405 ], [ %3408, %3407 ], !dbg !178 + %3409 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1502 = icmp eq i32 %3409, 0, !dbg !178 + br i1 %.not.i1502, label %3412, label %3410, !dbg !178 + +3410: ; preds = %__nv_exp2f.exit1501 + %3411 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3265) #3, !dbg !178 + br label %__nv_exp2f.exit1504, !dbg !178 + +3412: ; preds = %__nv_exp2f.exit1501 + %3413 = tail call float @llvm.nvvm.ex2.approx.f(float %3265) #3, !dbg !178 + br label %__nv_exp2f.exit1504, !dbg !178 + +__nv_exp2f.exit1504: ; preds = %3410, %3412 + %.0.i1503 = phi float [ %3411, %3410 ], [ %3413, %3412 ], !dbg !178 + %3414 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1505 = icmp eq i32 %3414, 0, !dbg !178 + br i1 %.not.i1505, label %3417, label %3415, !dbg !178 + +3415: ; preds = %__nv_exp2f.exit1504 + %3416 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3266) #3, !dbg !178 + br label %__nv_exp2f.exit1507, !dbg !178 + +3417: ; preds = %__nv_exp2f.exit1504 + %3418 = tail call float @llvm.nvvm.ex2.approx.f(float %3266) #3, !dbg !178 + br label %__nv_exp2f.exit1507, !dbg !178 + +__nv_exp2f.exit1507: ; preds = %3415, %3417 + %.0.i1506 = phi float [ %3416, %3415 ], [ %3418, %3417 ], !dbg !178 + %3419 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1508 = icmp eq i32 %3419, 0, !dbg !178 + br i1 %.not.i1508, label %3422, label %3420, !dbg !178 + +3420: ; preds = %__nv_exp2f.exit1507 + %3421 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3267) #3, !dbg !178 + br label %__nv_exp2f.exit1510, !dbg !178 + +3422: ; preds = %__nv_exp2f.exit1507 + %3423 = tail call float @llvm.nvvm.ex2.approx.f(float %3267) #3, !dbg !178 + br label %__nv_exp2f.exit1510, !dbg !178 + +__nv_exp2f.exit1510: ; preds = %3420, %3422 + %.0.i1509 = phi float [ %3421, %3420 ], [ %3423, %3422 ], !dbg !178 + %3424 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !178 + %.not.i1511 = icmp eq i32 %3424, 0, !dbg !178 + br i1 %.not.i1511, label %3427, label %3425, !dbg !178 + +3425: ; preds = %__nv_exp2f.exit1510 + %3426 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3268) #3, !dbg !178 + br label %__nv_exp2f.exit1513, !dbg !178 + +3427: ; preds = %__nv_exp2f.exit1510 + %3428 = tail call float @llvm.nvvm.ex2.approx.f(float %3268) #3, !dbg !178 + br label %__nv_exp2f.exit1513, !dbg !178 + +__nv_exp2f.exit1513: ; preds = %3425, %3427 + %.0.i1512 = phi float [ %3426, %3425 ], [ %3428, %3427 ], !dbg !178 + %3429 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2752, !dbg !169 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !179 + %3430 = add i32 %2756, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3431 = lshr exact i32 %3430, 4, !dbg !179 + %3432 = and i32 %3431, 16383, !dbg !179 + %3433 = zext nneg i32 %3432 to i64, !dbg !179 + %3434 = or disjoint i64 %3433, 4611686293372403712, !dbg !179 + %3435 = ptrtoint ptr addrspace(3) %3429 to i32, !dbg !179 + %3436 = lshr exact i32 %3435, 4, !dbg !179 + %3437 = and i32 %3436, 16383, !dbg !179 + %3438 = zext nneg i32 %3437 to i64, !dbg !179 + %3439 = or disjoint i64 %3438, 4611686293338849280, !dbg !179 + %3440 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %3434, i64 %3439) #3, !dbg !179 + %3441 = add i32 %2768, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3442 = lshr exact i32 %3441, 4, !dbg !179 + %3443 = and i32 %3442, 16383, !dbg !179 + %3444 = zext nneg i32 %3443 to i64, !dbg !179 + %3445 = or disjoint i64 %3444, 4611686293372403712, !dbg !179 + %3446 = add i32 %3435, 32, !dbg !179 + %3447 = lshr exact i32 %3446, 4, !dbg !179 + %3448 = and i32 %3447, 16383, !dbg !179 + %3449 = zext nneg i32 %3448 to i64, !dbg !179 + %3450 = or disjoint i64 %3449, 4611686293338849280, !dbg !179 + %3451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 0, !dbg !179 + %3452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 1, !dbg !179 + %3453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 2, !dbg !179 + %3454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 3, !dbg !179 + %3455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 4, !dbg !179 + %3456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 5, !dbg !179 + %3457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 6, !dbg !179 + %3458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 7, !dbg !179 + %3459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 8, !dbg !179 + %3460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 9, !dbg !179 + %3461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 10, !dbg !179 + %3462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 11, !dbg !179 + %3463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 12, !dbg !179 + %3464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 13, !dbg !179 + %3465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 14, !dbg !179 + %3466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 15, !dbg !179 + %3467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 16, !dbg !179 + %3468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 17, !dbg !179 + %3469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 18, !dbg !179 + %3470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 19, !dbg !179 + %3471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 20, !dbg !179 + %3472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 21, !dbg !179 + %3473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 22, !dbg !179 + %3474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 23, !dbg !179 + %3475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 24, !dbg !179 + %3476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 25, !dbg !179 + %3477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 26, !dbg !179 + %3478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 27, !dbg !179 + %3479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 28, !dbg !179 + %3480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 29, !dbg !179 + %3481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 30, !dbg !179 + %3482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3440, 31, !dbg !179 + %3483 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3451, float %3452, float %3453, float %3454, float %3455, float %3456, float %3457, float %3458, float %3459, float %3460, float %3461, float %3462, float %3463, float %3464, float %3465, float %3466, float %3467, float %3468, float %3469, float %3470, float %3471, float %3472, float %3473, float %3474, float %3475, float %3476, float %3477, float %3478, float %3479, float %3480, float %3481, float %3482, i64 %3445, i64 %3450, i1 true) #3, !dbg !179 + %3484 = add i32 %2812, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3485 = lshr exact i32 %3484, 4, !dbg !179 + %3486 = and i32 %3485, 16383, !dbg !179 + %3487 = zext nneg i32 %3486 to i64, !dbg !179 + %3488 = or disjoint i64 %3487, 4611686293372403712, !dbg !179 + %3489 = add i32 %3435, 64, !dbg !179 + %3490 = lshr exact i32 %3489, 4, !dbg !179 + %3491 = and i32 %3490, 16383, !dbg !179 + %3492 = zext nneg i32 %3491 to i64, !dbg !179 + %3493 = or disjoint i64 %3492, 4611686293338849280, !dbg !179 + %3494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 0, !dbg !179 + %3495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 1, !dbg !179 + %3496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 2, !dbg !179 + %3497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 3, !dbg !179 + %3498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 4, !dbg !179 + %3499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 5, !dbg !179 + %3500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 6, !dbg !179 + %3501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 7, !dbg !179 + %3502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 8, !dbg !179 + %3503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 9, !dbg !179 + %3504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 10, !dbg !179 + %3505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 11, !dbg !179 + %3506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 12, !dbg !179 + %3507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 13, !dbg !179 + %3508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 14, !dbg !179 + %3509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 15, !dbg !179 + %3510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 16, !dbg !179 + %3511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 17, !dbg !179 + %3512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 18, !dbg !179 + %3513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 19, !dbg !179 + %3514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 20, !dbg !179 + %3515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 21, !dbg !179 + %3516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 22, !dbg !179 + %3517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 23, !dbg !179 + %3518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 24, !dbg !179 + %3519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 25, !dbg !179 + %3520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 26, !dbg !179 + %3521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 27, !dbg !179 + %3522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 28, !dbg !179 + %3523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 29, !dbg !179 + %3524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 30, !dbg !179 + %3525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3483, 31, !dbg !179 + %3526 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3494, float %3495, float %3496, float %3497, float %3498, float %3499, float %3500, float %3501, float %3502, float %3503, float %3504, float %3505, float %3506, float %3507, float %3508, float %3509, float %3510, float %3511, float %3512, float %3513, float %3514, float %3515, float %3516, float %3517, float %3518, float %3519, float %3520, float %3521, float %3522, float %3523, float %3524, float %3525, i64 %3488, i64 %3493, i1 true) #3, !dbg !179 + %3527 = add i32 %2856, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3528 = lshr exact i32 %3527, 4, !dbg !179 + %3529 = and i32 %3528, 16383, !dbg !179 + %3530 = zext nneg i32 %3529 to i64, !dbg !179 + %3531 = or disjoint i64 %3530, 4611686293372403712, !dbg !179 + %3532 = add i32 %3435, 96, !dbg !179 + %3533 = lshr exact i32 %3532, 4, !dbg !179 + %3534 = and i32 %3533, 16383, !dbg !179 + %3535 = zext nneg i32 %3534 to i64, !dbg !179 + %3536 = or disjoint i64 %3535, 4611686293338849280, !dbg !179 + %3537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 0, !dbg !179 + %3538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 1, !dbg !179 + %3539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 2, !dbg !179 + %3540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 3, !dbg !179 + %3541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 4, !dbg !179 + %3542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 5, !dbg !179 + %3543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 6, !dbg !179 + %3544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 7, !dbg !179 + %3545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 8, !dbg !179 + %3546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 9, !dbg !179 + %3547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 10, !dbg !179 + %3548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 11, !dbg !179 + %3549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 12, !dbg !179 + %3550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 13, !dbg !179 + %3551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 14, !dbg !179 + %3552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 15, !dbg !179 + %3553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 16, !dbg !179 + %3554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 17, !dbg !179 + %3555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 18, !dbg !179 + %3556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 19, !dbg !179 + %3557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 20, !dbg !179 + %3558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 21, !dbg !179 + %3559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 22, !dbg !179 + %3560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 23, !dbg !179 + %3561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 24, !dbg !179 + %3562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 25, !dbg !179 + %3563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 26, !dbg !179 + %3564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 27, !dbg !179 + %3565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 28, !dbg !179 + %3566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 29, !dbg !179 + %3567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 30, !dbg !179 + %3568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3526, 31, !dbg !179 + %3569 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3537, float %3538, float %3539, float %3540, float %3541, float %3542, float %3543, float %3544, float %3545, float %3546, float %3547, float %3548, float %3549, float %3550, float %3551, float %3552, float %3553, float %3554, float %3555, float %3556, float %3557, float %3558, float %3559, float %3560, float %3561, float %3562, float %3563, float %3564, float %3565, float %3566, float %3567, float %3568, i64 %3531, i64 %3536, i1 true) #3, !dbg !179 + %3570 = add i32 %2900, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3571 = lshr exact i32 %3570, 4, !dbg !179 + %3572 = and i32 %3571, 16383, !dbg !179 + %3573 = zext nneg i32 %3572 to i64, !dbg !179 + %3574 = or disjoint i64 %3573, 4611686293372403712, !dbg !179 + %3575 = add i32 %3435, 8192, !dbg !179 + %3576 = lshr exact i32 %3575, 4, !dbg !179 + %3577 = and i32 %3576, 16383, !dbg !179 + %3578 = zext nneg i32 %3577 to i64, !dbg !179 + %3579 = or disjoint i64 %3578, 4611686293338849280, !dbg !179 + %3580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 0, !dbg !179 + %3581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 1, !dbg !179 + %3582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 2, !dbg !179 + %3583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 3, !dbg !179 + %3584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 4, !dbg !179 + %3585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 5, !dbg !179 + %3586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 6, !dbg !179 + %3587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 7, !dbg !179 + %3588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 8, !dbg !179 + %3589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 9, !dbg !179 + %3590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 10, !dbg !179 + %3591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 11, !dbg !179 + %3592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 12, !dbg !179 + %3593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 13, !dbg !179 + %3594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 14, !dbg !179 + %3595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 15, !dbg !179 + %3596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 16, !dbg !179 + %3597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 17, !dbg !179 + %3598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 18, !dbg !179 + %3599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 19, !dbg !179 + %3600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 20, !dbg !179 + %3601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 21, !dbg !179 + %3602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 22, !dbg !179 + %3603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 23, !dbg !179 + %3604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 24, !dbg !179 + %3605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 25, !dbg !179 + %3606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 26, !dbg !179 + %3607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 27, !dbg !179 + %3608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 28, !dbg !179 + %3609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 29, !dbg !179 + %3610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 30, !dbg !179 + %3611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3569, 31, !dbg !179 + %3612 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3580, float %3581, float %3582, float %3583, float %3584, float %3585, float %3586, float %3587, float %3588, float %3589, float %3590, float %3591, float %3592, float %3593, float %3594, float %3595, float %3596, float %3597, float %3598, float %3599, float %3600, float %3601, float %3602, float %3603, float %3604, float %3605, float %3606, float %3607, float %3608, float %3609, float %3610, float %3611, i64 %3574, i64 %3579, i1 true) #3, !dbg !179 + %3613 = add i32 %2944, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3614 = lshr exact i32 %3613, 4, !dbg !179 + %3615 = and i32 %3614, 16383, !dbg !179 + %3616 = zext nneg i32 %3615 to i64, !dbg !179 + %3617 = or disjoint i64 %3616, 4611686293372403712, !dbg !179 + %3618 = add i32 %3435, 8224, !dbg !179 + %3619 = lshr exact i32 %3618, 4, !dbg !179 + %3620 = and i32 %3619, 16383, !dbg !179 + %3621 = zext nneg i32 %3620 to i64, !dbg !179 + %3622 = or disjoint i64 %3621, 4611686293338849280, !dbg !179 + %3623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 0, !dbg !179 + %3624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 1, !dbg !179 + %3625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 2, !dbg !179 + %3626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 3, !dbg !179 + %3627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 4, !dbg !179 + %3628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 5, !dbg !179 + %3629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 6, !dbg !179 + %3630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 7, !dbg !179 + %3631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 8, !dbg !179 + %3632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 9, !dbg !179 + %3633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 10, !dbg !179 + %3634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 11, !dbg !179 + %3635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 12, !dbg !179 + %3636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 13, !dbg !179 + %3637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 14, !dbg !179 + %3638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 15, !dbg !179 + %3639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 16, !dbg !179 + %3640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 17, !dbg !179 + %3641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 18, !dbg !179 + %3642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 19, !dbg !179 + %3643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 20, !dbg !179 + %3644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 21, !dbg !179 + %3645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 22, !dbg !179 + %3646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 23, !dbg !179 + %3647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 24, !dbg !179 + %3648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 25, !dbg !179 + %3649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 26, !dbg !179 + %3650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 27, !dbg !179 + %3651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 28, !dbg !179 + %3652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 29, !dbg !179 + %3653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 30, !dbg !179 + %3654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3612, 31, !dbg !179 + %3655 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3623, float %3624, float %3625, float %3626, float %3627, float %3628, float %3629, float %3630, float %3631, float %3632, float %3633, float %3634, float %3635, float %3636, float %3637, float %3638, float %3639, float %3640, float %3641, float %3642, float %3643, float %3644, float %3645, float %3646, float %3647, float %3648, float %3649, float %3650, float %3651, float %3652, float %3653, float %3654, i64 %3617, i64 %3622, i1 true) #3, !dbg !179 + %3656 = add i32 %2988, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3657 = lshr exact i32 %3656, 4, !dbg !179 + %3658 = and i32 %3657, 16383, !dbg !179 + %3659 = zext nneg i32 %3658 to i64, !dbg !179 + %3660 = or disjoint i64 %3659, 4611686293372403712, !dbg !179 + %3661 = add i32 %3435, 8256, !dbg !179 + %3662 = lshr exact i32 %3661, 4, !dbg !179 + %3663 = and i32 %3662, 16383, !dbg !179 + %3664 = zext nneg i32 %3663 to i64, !dbg !179 + %3665 = or disjoint i64 %3664, 4611686293338849280, !dbg !179 + %3666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 0, !dbg !179 + %3667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 1, !dbg !179 + %3668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 2, !dbg !179 + %3669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 3, !dbg !179 + %3670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 4, !dbg !179 + %3671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 5, !dbg !179 + %3672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 6, !dbg !179 + %3673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 7, !dbg !179 + %3674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 8, !dbg !179 + %3675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 9, !dbg !179 + %3676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 10, !dbg !179 + %3677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 11, !dbg !179 + %3678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 12, !dbg !179 + %3679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 13, !dbg !179 + %3680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 14, !dbg !179 + %3681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 15, !dbg !179 + %3682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 16, !dbg !179 + %3683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 17, !dbg !179 + %3684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 18, !dbg !179 + %3685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 19, !dbg !179 + %3686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 20, !dbg !179 + %3687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 21, !dbg !179 + %3688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 22, !dbg !179 + %3689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 23, !dbg !179 + %3690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 24, !dbg !179 + %3691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 25, !dbg !179 + %3692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 26, !dbg !179 + %3693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 27, !dbg !179 + %3694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 28, !dbg !179 + %3695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 29, !dbg !179 + %3696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 30, !dbg !179 + %3697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3655, 31, !dbg !179 + %3698 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3666, float %3667, float %3668, float %3669, float %3670, float %3671, float %3672, float %3673, float %3674, float %3675, float %3676, float %3677, float %3678, float %3679, float %3680, float %3681, float %3682, float %3683, float %3684, float %3685, float %3686, float %3687, float %3688, float %3689, float %3690, float %3691, float %3692, float %3693, float %3694, float %3695, float %3696, float %3697, i64 %3660, i64 %3665, i1 true) #3, !dbg !179 + %3699 = add i32 %3032, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !179 + %3700 = lshr exact i32 %3699, 4, !dbg !179 + %3701 = and i32 %3700, 16383, !dbg !179 + %3702 = zext nneg i32 %3701 to i64, !dbg !179 + %3703 = or disjoint i64 %3702, 4611686293372403712, !dbg !179 + %3704 = add i32 %3435, 8288, !dbg !179 + %3705 = lshr exact i32 %3704, 4, !dbg !179 + %3706 = and i32 %3705, 16383, !dbg !179 + %3707 = zext nneg i32 %3706 to i64, !dbg !179 + %3708 = or disjoint i64 %3707, 4611686293338849280, !dbg !179 + %3709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 0, !dbg !179 + %3710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 1, !dbg !179 + %3711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 2, !dbg !179 + %3712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 3, !dbg !179 + %3713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 4, !dbg !179 + %3714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 5, !dbg !179 + %3715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 6, !dbg !179 + %3716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 7, !dbg !179 + %3717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 8, !dbg !179 + %3718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 9, !dbg !179 + %3719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 10, !dbg !179 + %3720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 11, !dbg !179 + %3721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 12, !dbg !179 + %3722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 13, !dbg !179 + %3723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 14, !dbg !179 + %3724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 15, !dbg !179 + %3725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 16, !dbg !179 + %3726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 17, !dbg !179 + %3727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 18, !dbg !179 + %3728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 19, !dbg !179 + %3729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 20, !dbg !179 + %3730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 21, !dbg !179 + %3731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 22, !dbg !179 + %3732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 23, !dbg !179 + %3733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 24, !dbg !179 + %3734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 25, !dbg !179 + %3735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 26, !dbg !179 + %3736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 27, !dbg !179 + %3737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 28, !dbg !179 + %3738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 29, !dbg !179 + %3739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 30, !dbg !179 + %3740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3698, 31, !dbg !179 + %3741 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3709, float %3710, float %3711, float %3712, float %3713, float %3714, float %3715, float %3716, float %3717, float %3718, float %3719, float %3720, float %3721, float %3722, float %3723, float %3724, float %3725, float %3726, float %3727, float %3728, float %3729, float %3730, float %3731, float %3732, float %3733, float %3734, float %3735, float %3736, float %3737, float %3738, float %3739, float %3740, i64 %3703, i64 %3708, i1 true) #3, !dbg !179 + %3742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 0, !dbg !179 + %3743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 1, !dbg !179 + %3744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 2, !dbg !179 + %3745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 3, !dbg !179 + %3746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 4, !dbg !179 + %3747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 5, !dbg !179 + %3748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 6, !dbg !179 + %3749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 7, !dbg !179 + %3750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 8, !dbg !179 + %3751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 9, !dbg !179 + %3752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 10, !dbg !179 + %3753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 11, !dbg !179 + %3754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 12, !dbg !179 + %3755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 13, !dbg !179 + %3756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 14, !dbg !179 + %3757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 15, !dbg !179 + %3758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 16, !dbg !179 + %3759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 17, !dbg !179 + %3760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 18, !dbg !179 + %3761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 19, !dbg !179 + %3762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 20, !dbg !179 + %3763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 21, !dbg !179 + %3764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 22, !dbg !179 + %3765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 23, !dbg !179 + %3766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 24, !dbg !179 + %3767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 25, !dbg !179 + %3768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 26, !dbg !179 + %3769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 27, !dbg !179 + %3770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 28, !dbg !179 + %3771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 29, !dbg !179 + %3772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 30, !dbg !179 + %3773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3741, 31, !dbg !179 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !179 + %3774 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %3742, float %3743, float %3744, float %3745, float %3746, float %3747, float %3748, float %3749, float %3750, float %3751, float %3752, float %3753, float %3754, float %3755, float %3756, float %3757, float %3758, float %3759, float %3760, float %3761, float %3762, float %3763, float %3764, float %3765, float %3766, float %3767, float %3768, float %3769, float %3770, float %3771, float %3772, float %3773, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %3429, i32 0, i32 0) #3, !dbg !179 + %3775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 0, !dbg !179 + %3776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 1, !dbg !179 + %3777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 2, !dbg !179 + %3778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 3, !dbg !179 + %3779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 4, !dbg !179 + %3780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 5, !dbg !179 + %3781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 6, !dbg !179 + %3782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 7, !dbg !179 + %3783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 8, !dbg !179 + %3784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 9, !dbg !179 + %3785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 10, !dbg !179 + %3786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 11, !dbg !179 + %3787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 12, !dbg !179 + %3788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 13, !dbg !179 + %3789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 14, !dbg !179 + %3790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 15, !dbg !179 + %3791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 16, !dbg !179 + %3792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 17, !dbg !179 + %3793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 18, !dbg !179 + %3794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 19, !dbg !179 + %3795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 20, !dbg !179 + %3796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 21, !dbg !179 + %3797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 22, !dbg !179 + %3798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 23, !dbg !179 + %3799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 24, !dbg !179 + %3800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 25, !dbg !179 + %3801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 26, !dbg !179 + %3802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 27, !dbg !179 + %3803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 28, !dbg !179 + %3804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 29, !dbg !179 + %3805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 30, !dbg !179 + %3806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3774, 31, !dbg !179 + %3807 = fsub float %3775, %362, !dbg !180 + %3808 = fsub float %3776, %362, !dbg !180 + %3809 = fsub float %3777, %364, !dbg !180 + %3810 = fsub float %3778, %364, !dbg !180 + %3811 = fsub float %3779, %362, !dbg !180 + %3812 = fsub float %3780, %362, !dbg !180 + %3813 = fsub float %3781, %364, !dbg !180 + %3814 = fsub float %3782, %364, !dbg !180 + %3815 = fsub float %3783, %362, !dbg !180 + %3816 = fsub float %3784, %362, !dbg !180 + %3817 = fsub float %3785, %364, !dbg !180 + %3818 = fsub float %3786, %364, !dbg !180 + %3819 = fsub float %3787, %362, !dbg !180 + %3820 = fsub float %3788, %362, !dbg !180 + %3821 = fsub float %3789, %364, !dbg !180 + %3822 = fsub float %3790, %364, !dbg !180 + %3823 = fsub float %3791, %362, !dbg !180 + %3824 = fsub float %3792, %362, !dbg !180 + %3825 = fsub float %3793, %364, !dbg !180 + %3826 = fsub float %3794, %364, !dbg !180 + %3827 = fsub float %3795, %362, !dbg !180 + %3828 = fsub float %3796, %362, !dbg !180 + %3829 = fsub float %3797, %364, !dbg !180 + %3830 = fsub float %3798, %364, !dbg !180 + %3831 = fsub float %3799, %362, !dbg !180 + %3832 = fsub float %3800, %362, !dbg !180 + %3833 = fsub float %3801, %364, !dbg !180 + %3834 = fsub float %3802, %364, !dbg !180 + %3835 = fsub float %3803, %362, !dbg !180 + %3836 = fsub float %3804, %362, !dbg !180 + %3837 = fsub float %3805, %364, !dbg !180 + %3838 = fsub float %3806, %364, !dbg !180 + %3839 = fmul float %.0.i1419, %3807, !dbg !181 + %3840 = fmul float %.0.i1422, %3808, !dbg !181 + %3841 = fmul float %.0.i1425, %3809, !dbg !181 + %3842 = fmul float %.0.i1428, %3810, !dbg !181 + %3843 = fmul float %.0.i1431, %3811, !dbg !181 + %3844 = fmul float %.0.i1434, %3812, !dbg !181 + %3845 = fmul float %.0.i1437, %3813, !dbg !181 + %3846 = fmul float %.0.i1440, %3814, !dbg !181 + %3847 = fmul float %.0.i1443, %3815, !dbg !181 + %3848 = fmul float %.0.i1446, %3816, !dbg !181 + %3849 = fmul float %.0.i1449, %3817, !dbg !181 + %3850 = fmul float %.0.i1452, %3818, !dbg !181 + %3851 = fmul float %.0.i1455, %3819, !dbg !181 + %3852 = fmul float %.0.i1458, %3820, !dbg !181 + %3853 = fmul float %.0.i1461, %3821, !dbg !181 + %3854 = fmul float %.0.i1464, %3822, !dbg !181 + %3855 = fmul float %.0.i1467, %3823, !dbg !181 + %3856 = fmul float %.0.i1470, %3824, !dbg !181 + %3857 = fmul float %.0.i1473, %3825, !dbg !181 + %3858 = fmul float %.0.i1476, %3826, !dbg !181 + %3859 = fmul float %.0.i1479, %3827, !dbg !181 + %3860 = fmul float %.0.i1482, %3828, !dbg !181 + %3861 = fmul float %.0.i1485, %3829, !dbg !181 + %3862 = fmul float %.0.i1488, %3830, !dbg !181 + %3863 = fmul float %.0.i1491, %3831, !dbg !181 + %3864 = fmul float %.0.i1494, %3832, !dbg !181 + %3865 = fmul float %.0.i1497, %3833, !dbg !181 + %3866 = fmul float %.0.i1500, %3834, !dbg !181 + %3867 = fmul float %.0.i1503, %3835, !dbg !181 + %3868 = fmul float %.0.i1506, %3836, !dbg !181 + %3869 = fmul float %.0.i1509, %3837, !dbg !181 + %3870 = fmul float %.0.i1512, %3838, !dbg !181 + %3871 = fptrunc float %3839 to bfloat, !dbg !182 + %3872 = select i1 %2721, bfloat %3871, bfloat 0xR0000, !dbg !183 + %3873 = fptrunc float %3840 to bfloat, !dbg !182 + %3874 = select i1 %2723, bfloat %3873, bfloat 0xR0000, !dbg !183 + %3875 = fptrunc float %3841 to bfloat, !dbg !182 + %3876 = select i1 %2721, bfloat %3875, bfloat 0xR0000, !dbg !183 + %3877 = fptrunc float %3842 to bfloat, !dbg !182 + %3878 = select i1 %2723, bfloat %3877, bfloat 0xR0000, !dbg !183 + %3879 = fptrunc float %3843 to bfloat, !dbg !182 + %3880 = select i1 %2725, bfloat %3879, bfloat 0xR0000, !dbg !183 + %3881 = fptrunc float %3844 to bfloat, !dbg !182 + %3882 = select i1 %2727, bfloat %3881, bfloat 0xR0000, !dbg !183 + %3883 = fptrunc float %3845 to bfloat, !dbg !182 + %3884 = select i1 %2725, bfloat %3883, bfloat 0xR0000, !dbg !183 + %3885 = fptrunc float %3846 to bfloat, !dbg !182 + %3886 = select i1 %2727, bfloat %3885, bfloat 0xR0000, !dbg !183 + %3887 = fptrunc float %3847 to bfloat, !dbg !182 + %3888 = select i1 %2729, bfloat %3887, bfloat 0xR0000, !dbg !183 + %3889 = fptrunc float %3848 to bfloat, !dbg !182 + %3890 = select i1 %2731, bfloat %3889, bfloat 0xR0000, !dbg !183 + %3891 = fptrunc float %3849 to bfloat, !dbg !182 + %3892 = select i1 %2729, bfloat %3891, bfloat 0xR0000, !dbg !183 + %3893 = fptrunc float %3850 to bfloat, !dbg !182 + %3894 = select i1 %2731, bfloat %3893, bfloat 0xR0000, !dbg !183 + %3895 = fptrunc float %3851 to bfloat, !dbg !182 + %3896 = select i1 %2733, bfloat %3895, bfloat 0xR0000, !dbg !183 + %3897 = fptrunc float %3852 to bfloat, !dbg !182 + %3898 = select i1 %2735, bfloat %3897, bfloat 0xR0000, !dbg !183 + %3899 = fptrunc float %3853 to bfloat, !dbg !182 + %3900 = select i1 %2733, bfloat %3899, bfloat 0xR0000, !dbg !183 + %3901 = fptrunc float %3854 to bfloat, !dbg !182 + %3902 = select i1 %2735, bfloat %3901, bfloat 0xR0000, !dbg !183 + %3903 = fptrunc float %3855 to bfloat, !dbg !182 + %3904 = select i1 %2737, bfloat %3903, bfloat 0xR0000, !dbg !183 + %3905 = fptrunc float %3856 to bfloat, !dbg !182 + %3906 = select i1 %2739, bfloat %3905, bfloat 0xR0000, !dbg !183 + %3907 = fptrunc float %3857 to bfloat, !dbg !182 + %3908 = select i1 %2737, bfloat %3907, bfloat 0xR0000, !dbg !183 + %3909 = fptrunc float %3858 to bfloat, !dbg !182 + %3910 = select i1 %2739, bfloat %3909, bfloat 0xR0000, !dbg !183 + %3911 = fptrunc float %3859 to bfloat, !dbg !182 + %3912 = select i1 %2741, bfloat %3911, bfloat 0xR0000, !dbg !183 + %3913 = fptrunc float %3860 to bfloat, !dbg !182 + %3914 = select i1 %2743, bfloat %3913, bfloat 0xR0000, !dbg !183 + %3915 = fptrunc float %3861 to bfloat, !dbg !182 + %3916 = select i1 %2741, bfloat %3915, bfloat 0xR0000, !dbg !183 + %3917 = fptrunc float %3862 to bfloat, !dbg !182 + %3918 = select i1 %2743, bfloat %3917, bfloat 0xR0000, !dbg !183 + %3919 = fptrunc float %3863 to bfloat, !dbg !182 + %3920 = select i1 %2745, bfloat %3919, bfloat 0xR0000, !dbg !183 + %3921 = fptrunc float %3864 to bfloat, !dbg !182 + %3922 = select i1 %2747, bfloat %3921, bfloat 0xR0000, !dbg !183 + %3923 = fptrunc float %3865 to bfloat, !dbg !182 + %3924 = select i1 %2745, bfloat %3923, bfloat 0xR0000, !dbg !183 + %3925 = fptrunc float %3866 to bfloat, !dbg !182 + %3926 = select i1 %2747, bfloat %3925, bfloat 0xR0000, !dbg !183 + %3927 = fptrunc float %3867 to bfloat, !dbg !182 + %3928 = select i1 %2749, bfloat %3927, bfloat 0xR0000, !dbg !183 + %3929 = fptrunc float %3868 to bfloat, !dbg !182 + %3930 = select i1 %2751, bfloat %3929, bfloat 0xR0000, !dbg !183 + %3931 = fptrunc float %3869 to bfloat, !dbg !182 + %3932 = select i1 %2749, bfloat %3931, bfloat 0xR0000, !dbg !183 + %3933 = fptrunc float %3870 to bfloat, !dbg !182 + %3934 = select i1 %2751, bfloat %3933, bfloat 0xR0000, !dbg !183 + %3935 = insertelement <2 x bfloat> poison, bfloat %3872, i64 0, !dbg !184 + %3936 = insertelement <2 x bfloat> %3935, bfloat %3874, i64 1, !dbg !184 + %3937 = bitcast <2 x bfloat> %3936 to i32, !dbg !184 + %3938 = insertelement <2 x bfloat> poison, bfloat %3876, i64 0, !dbg !184 + %3939 = insertelement <2 x bfloat> %3938, bfloat %3878, i64 1, !dbg !184 + %3940 = bitcast <2 x bfloat> %3939 to i32, !dbg !184 + %3941 = insertelement <2 x bfloat> poison, bfloat %3880, i64 0, !dbg !184 + %3942 = insertelement <2 x bfloat> %3941, bfloat %3882, i64 1, !dbg !184 + %3943 = bitcast <2 x bfloat> %3942 to i32, !dbg !184 + %3944 = insertelement <2 x bfloat> poison, bfloat %3884, i64 0, !dbg !184 + %3945 = insertelement <2 x bfloat> %3944, bfloat %3886, i64 1, !dbg !184 + %3946 = bitcast <2 x bfloat> %3945 to i32, !dbg !184 + %3947 = insertelement <2 x bfloat> poison, bfloat %3888, i64 0, !dbg !184 + %3948 = insertelement <2 x bfloat> %3947, bfloat %3890, i64 1, !dbg !184 + %3949 = bitcast <2 x bfloat> %3948 to i32, !dbg !184 + %3950 = insertelement <2 x bfloat> poison, bfloat %3892, i64 0, !dbg !184 + %3951 = insertelement <2 x bfloat> %3950, bfloat %3894, i64 1, !dbg !184 + %3952 = bitcast <2 x bfloat> %3951 to i32, !dbg !184 + %3953 = insertelement <2 x bfloat> poison, bfloat %3896, i64 0, !dbg !184 + %3954 = insertelement <2 x bfloat> %3953, bfloat %3898, i64 1, !dbg !184 + %3955 = bitcast <2 x bfloat> %3954 to i32, !dbg !184 + %3956 = insertelement <2 x bfloat> poison, bfloat %3900, i64 0, !dbg !184 + %3957 = insertelement <2 x bfloat> %3956, bfloat %3902, i64 1, !dbg !184 + %3958 = bitcast <2 x bfloat> %3957 to i32, !dbg !184 + %3959 = insertelement <2 x bfloat> poison, bfloat %3904, i64 0, !dbg !184 + %3960 = insertelement <2 x bfloat> %3959, bfloat %3906, i64 1, !dbg !184 + %3961 = bitcast <2 x bfloat> %3960 to i32, !dbg !184 + %3962 = insertelement <2 x bfloat> poison, bfloat %3908, i64 0, !dbg !184 + %3963 = insertelement <2 x bfloat> %3962, bfloat %3910, i64 1, !dbg !184 + %3964 = bitcast <2 x bfloat> %3963 to i32, !dbg !184 + %3965 = insertelement <2 x bfloat> poison, bfloat %3912, i64 0, !dbg !184 + %3966 = insertelement <2 x bfloat> %3965, bfloat %3914, i64 1, !dbg !184 + %3967 = bitcast <2 x bfloat> %3966 to i32, !dbg !184 + %3968 = insertelement <2 x bfloat> poison, bfloat %3916, i64 0, !dbg !184 + %3969 = insertelement <2 x bfloat> %3968, bfloat %3918, i64 1, !dbg !184 + %3970 = bitcast <2 x bfloat> %3969 to i32, !dbg !184 + %3971 = insertelement <2 x bfloat> poison, bfloat %3920, i64 0, !dbg !184 + %3972 = insertelement <2 x bfloat> %3971, bfloat %3922, i64 1, !dbg !184 + %3973 = bitcast <2 x bfloat> %3972 to i32, !dbg !184 + %3974 = insertelement <2 x bfloat> poison, bfloat %3924, i64 0, !dbg !184 + %3975 = insertelement <2 x bfloat> %3974, bfloat %3926, i64 1, !dbg !184 + %3976 = bitcast <2 x bfloat> %3975 to i32, !dbg !184 + %3977 = insertelement <2 x bfloat> poison, bfloat %3928, i64 0, !dbg !184 + %3978 = insertelement <2 x bfloat> %3977, bfloat %3930, i64 1, !dbg !184 + %3979 = bitcast <2 x bfloat> %3978 to i32, !dbg !184 + %3980 = insertelement <2 x bfloat> poison, bfloat %3932, i64 0, !dbg !184 + %3981 = insertelement <2 x bfloat> %3980, bfloat %3934, i64 1, !dbg !184 + %3982 = bitcast <2 x bfloat> %3981 to i32, !dbg !184 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !184 + %3983 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn, float %.pn2530, float %.pn2531, float %.pn2532, float %.pn2533, float %.pn2534, float %.pn2535, float %.pn2536, float %.pn2537, float %.pn2538, float %.pn2539, float %.pn2540, float %.pn2541, float %.pn2542, float %.pn2543, float %.pn2544, float %.pn2545, float %.pn2546, float %.pn2547, float %.pn2548, float %.pn2549, float %.pn2550, float %.pn2551, float %.pn2552, float %.pn2553, float %.pn2554, float %.pn2555, float %.pn2556, float %.pn2557, float %.pn2558, float %.pn2559, float %.pn2560, float %.pn2561, float %.pn2562, float %.pn2563, float %.pn2564, float %.pn2565, float %.pn2566, float %.pn2567, float %.pn2568, float %.pn2569, float %.pn2570, float %.pn2571, float %.pn2572, float %.pn2573, float %.pn2574, float %.pn2575, float %.pn2576, float %.pn2577, float %.pn2578, float %.pn2579, float %.pn2580, float %.pn2581, float %.pn2582, float %.pn2583, float %.pn2584, float %.pn2585, float %.pn2586, float %.pn2587, float %.pn2588, float %.pn2589, float %.pn2590, float %.pn2591, float %.pn2592, i32 %3937, i32 %3940, i32 %3943, i32 %3946, i64 %2766, i1 true) #3, !dbg !184 + %3984 = add i32 %2762, 2048, !dbg !184 + %3985 = lshr exact i32 %3984, 4, !dbg !184 + %3986 = and i32 %3985, 16383, !dbg !184 + %3987 = zext nneg i32 %3986 to i64, !dbg !184 + %3988 = or disjoint i64 %3987, 4611686293338849280, !dbg !184 + %3989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 0, !dbg !184 + %3990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 1, !dbg !184 + %3991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 2, !dbg !184 + %3992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 3, !dbg !184 + %3993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 4, !dbg !184 + %3994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 5, !dbg !184 + %3995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 6, !dbg !184 + %3996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 7, !dbg !184 + %3997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 8, !dbg !184 + %3998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 9, !dbg !184 + %3999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 10, !dbg !184 + %4000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 11, !dbg !184 + %4001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 12, !dbg !184 + %4002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 13, !dbg !184 + %4003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 14, !dbg !184 + %4004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 15, !dbg !184 + %4005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 16, !dbg !184 + %4006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 17, !dbg !184 + %4007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 18, !dbg !184 + %4008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 19, !dbg !184 + %4009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 20, !dbg !184 + %4010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 21, !dbg !184 + %4011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 22, !dbg !184 + %4012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 23, !dbg !184 + %4013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 24, !dbg !184 + %4014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 25, !dbg !184 + %4015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 26, !dbg !184 + %4016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 27, !dbg !184 + %4017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 28, !dbg !184 + %4018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 29, !dbg !184 + %4019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 30, !dbg !184 + %4020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 31, !dbg !184 + %4021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 32, !dbg !184 + %4022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 33, !dbg !184 + %4023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 34, !dbg !184 + %4024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 35, !dbg !184 + %4025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 36, !dbg !184 + %4026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 37, !dbg !184 + %4027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 38, !dbg !184 + %4028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 39, !dbg !184 + %4029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 40, !dbg !184 + %4030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 41, !dbg !184 + %4031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 42, !dbg !184 + %4032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 43, !dbg !184 + %4033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 44, !dbg !184 + %4034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 45, !dbg !184 + %4035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 46, !dbg !184 + %4036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 47, !dbg !184 + %4037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 48, !dbg !184 + %4038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 49, !dbg !184 + %4039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 50, !dbg !184 + %4040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 51, !dbg !184 + %4041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 52, !dbg !184 + %4042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 53, !dbg !184 + %4043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 54, !dbg !184 + %4044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 55, !dbg !184 + %4045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 56, !dbg !184 + %4046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 57, !dbg !184 + %4047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 58, !dbg !184 + %4048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 59, !dbg !184 + %4049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 60, !dbg !184 + %4050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 61, !dbg !184 + %4051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 62, !dbg !184 + %4052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3983, 63, !dbg !184 + %4053 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %3989, float %3990, float %3991, float %3992, float %3993, float %3994, float %3995, float %3996, float %3997, float %3998, float %3999, float %4000, float %4001, float %4002, float %4003, float %4004, float %4005, float %4006, float %4007, float %4008, float %4009, float %4010, float %4011, float %4012, float %4013, float %4014, float %4015, float %4016, float %4017, float %4018, float %4019, float %4020, float %4021, float %4022, float %4023, float %4024, float %4025, float %4026, float %4027, float %4028, float %4029, float %4030, float %4031, float %4032, float %4033, float %4034, float %4035, float %4036, float %4037, float %4038, float %4039, float %4040, float %4041, float %4042, float %4043, float %4044, float %4045, float %4046, float %4047, float %4048, float %4049, float %4050, float %4051, float %4052, i32 %3949, i32 %3952, i32 %3955, i32 %3958, i64 %3988, i1 true) #3, !dbg !184 + %4054 = add i32 %2762, 4096, !dbg !184 + %4055 = lshr exact i32 %4054, 4, !dbg !184 + %4056 = and i32 %4055, 16383, !dbg !184 + %4057 = zext nneg i32 %4056 to i64, !dbg !184 + %4058 = or disjoint i64 %4057, 4611686293338849280, !dbg !184 + %4059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 0, !dbg !184 + %4060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 1, !dbg !184 + %4061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 2, !dbg !184 + %4062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 3, !dbg !184 + %4063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 4, !dbg !184 + %4064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 5, !dbg !184 + %4065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 6, !dbg !184 + %4066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 7, !dbg !184 + %4067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 8, !dbg !184 + %4068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 9, !dbg !184 + %4069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 10, !dbg !184 + %4070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 11, !dbg !184 + %4071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 12, !dbg !184 + %4072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 13, !dbg !184 + %4073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 14, !dbg !184 + %4074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 15, !dbg !184 + %4075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 16, !dbg !184 + %4076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 17, !dbg !184 + %4077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 18, !dbg !184 + %4078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 19, !dbg !184 + %4079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 20, !dbg !184 + %4080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 21, !dbg !184 + %4081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 22, !dbg !184 + %4082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 23, !dbg !184 + %4083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 24, !dbg !184 + %4084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 25, !dbg !184 + %4085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 26, !dbg !184 + %4086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 27, !dbg !184 + %4087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 28, !dbg !184 + %4088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 29, !dbg !184 + %4089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 30, !dbg !184 + %4090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 31, !dbg !184 + %4091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 32, !dbg !184 + %4092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 33, !dbg !184 + %4093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 34, !dbg !184 + %4094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 35, !dbg !184 + %4095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 36, !dbg !184 + %4096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 37, !dbg !184 + %4097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 38, !dbg !184 + %4098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 39, !dbg !184 + %4099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 40, !dbg !184 + %4100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 41, !dbg !184 + %4101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 42, !dbg !184 + %4102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 43, !dbg !184 + %4103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 44, !dbg !184 + %4104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 45, !dbg !184 + %4105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 46, !dbg !184 + %4106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 47, !dbg !184 + %4107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 48, !dbg !184 + %4108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 49, !dbg !184 + %4109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 50, !dbg !184 + %4110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 51, !dbg !184 + %4111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 52, !dbg !184 + %4112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 53, !dbg !184 + %4113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 54, !dbg !184 + %4114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 55, !dbg !184 + %4115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 56, !dbg !184 + %4116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 57, !dbg !184 + %4117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 58, !dbg !184 + %4118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 59, !dbg !184 + %4119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 60, !dbg !184 + %4120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 61, !dbg !184 + %4121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 62, !dbg !184 + %4122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4053, 63, !dbg !184 + %4123 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %4059, float %4060, float %4061, float %4062, float %4063, float %4064, float %4065, float %4066, float %4067, float %4068, float %4069, float %4070, float %4071, float %4072, float %4073, float %4074, float %4075, float %4076, float %4077, float %4078, float %4079, float %4080, float %4081, float %4082, float %4083, float %4084, float %4085, float %4086, float %4087, float %4088, float %4089, float %4090, float %4091, float %4092, float %4093, float %4094, float %4095, float %4096, float %4097, float %4098, float %4099, float %4100, float %4101, float %4102, float %4103, float %4104, float %4105, float %4106, float %4107, float %4108, float %4109, float %4110, float %4111, float %4112, float %4113, float %4114, float %4115, float %4116, float %4117, float %4118, float %4119, float %4120, float %4121, float %4122, i32 %3961, i32 %3964, i32 %3967, i32 %3970, i64 %4058, i1 true) #3, !dbg !184 + %4124 = add i32 %2762, 6144, !dbg !184 + %4125 = lshr exact i32 %4124, 4, !dbg !184 + %4126 = and i32 %4125, 16383, !dbg !184 + %4127 = zext nneg i32 %4126 to i64, !dbg !184 + %4128 = or disjoint i64 %4127, 4611686293338849280, !dbg !184 + %4129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 0, !dbg !184 + %4130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 1, !dbg !184 + %4131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 2, !dbg !184 + %4132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 3, !dbg !184 + %4133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 4, !dbg !184 + %4134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 5, !dbg !184 + %4135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 6, !dbg !184 + %4136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 7, !dbg !184 + %4137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 8, !dbg !184 + %4138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 9, !dbg !184 + %4139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 10, !dbg !184 + %4140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 11, !dbg !184 + %4141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 12, !dbg !184 + %4142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 13, !dbg !184 + %4143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 14, !dbg !184 + %4144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 15, !dbg !184 + %4145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 16, !dbg !184 + %4146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 17, !dbg !184 + %4147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 18, !dbg !184 + %4148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 19, !dbg !184 + %4149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 20, !dbg !184 + %4150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 21, !dbg !184 + %4151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 22, !dbg !184 + %4152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 23, !dbg !184 + %4153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 24, !dbg !184 + %4154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 25, !dbg !184 + %4155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 26, !dbg !184 + %4156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 27, !dbg !184 + %4157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 28, !dbg !184 + %4158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 29, !dbg !184 + %4159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 30, !dbg !184 + %4160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 31, !dbg !184 + %4161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 32, !dbg !184 + %4162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 33, !dbg !184 + %4163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 34, !dbg !184 + %4164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 35, !dbg !184 + %4165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 36, !dbg !184 + %4166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 37, !dbg !184 + %4167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 38, !dbg !184 + %4168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 39, !dbg !184 + %4169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 40, !dbg !184 + %4170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 41, !dbg !184 + %4171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 42, !dbg !184 + %4172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 43, !dbg !184 + %4173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 44, !dbg !184 + %4174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 45, !dbg !184 + %4175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 46, !dbg !184 + %4176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 47, !dbg !184 + %4177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 48, !dbg !184 + %4178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 49, !dbg !184 + %4179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 50, !dbg !184 + %4180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 51, !dbg !184 + %4181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 52, !dbg !184 + %4182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 53, !dbg !184 + %4183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 54, !dbg !184 + %4184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 55, !dbg !184 + %4185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 56, !dbg !184 + %4186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 57, !dbg !184 + %4187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 58, !dbg !184 + %4188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 59, !dbg !184 + %4189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 60, !dbg !184 + %4190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 61, !dbg !184 + %4191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 62, !dbg !184 + %4192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4123, 63, !dbg !184 + %4193 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %4129, float %4130, float %4131, float %4132, float %4133, float %4134, float %4135, float %4136, float %4137, float %4138, float %4139, float %4140, float %4141, float %4142, float %4143, float %4144, float %4145, float %4146, float %4147, float %4148, float %4149, float %4150, float %4151, float %4152, float %4153, float %4154, float %4155, float %4156, float %4157, float %4158, float %4159, float %4160, float %4161, float %4162, float %4163, float %4164, float %4165, float %4166, float %4167, float %4168, float %4169, float %4170, float %4171, float %4172, float %4173, float %4174, float %4175, float %4176, float %4177, float %4178, float %4179, float %4180, float %4181, float %4182, float %4183, float %4184, float %4185, float %4186, float %4187, float %4188, float %4189, float %4190, float %4191, float %4192, i32 %3973, i32 %3976, i32 %3979, i32 %3982, i64 %4128, i1 true) #3, !dbg !184 + %4194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 0, !dbg !184 + %4195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 1, !dbg !184 + %4196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 2, !dbg !184 + %4197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 3, !dbg !184 + %4198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 4, !dbg !184 + %4199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 5, !dbg !184 + %4200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 6, !dbg !184 + %4201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 7, !dbg !184 + %4202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 8, !dbg !184 + %4203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 9, !dbg !184 + %4204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 10, !dbg !184 + %4205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 11, !dbg !184 + %4206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 12, !dbg !184 + %4207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 13, !dbg !184 + %4208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 14, !dbg !184 + %4209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 15, !dbg !184 + %4210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 16, !dbg !184 + %4211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 17, !dbg !184 + %4212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 18, !dbg !184 + %4213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 19, !dbg !184 + %4214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 20, !dbg !184 + %4215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 21, !dbg !184 + %4216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 22, !dbg !184 + %4217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 23, !dbg !184 + %4218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 24, !dbg !184 + %4219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 25, !dbg !184 + %4220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 26, !dbg !184 + %4221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 27, !dbg !184 + %4222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 28, !dbg !184 + %4223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 29, !dbg !184 + %4224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 30, !dbg !184 + %4225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 31, !dbg !184 + %4226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 32, !dbg !184 + %4227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 33, !dbg !184 + %4228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 34, !dbg !184 + %4229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 35, !dbg !184 + %4230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 36, !dbg !184 + %4231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 37, !dbg !184 + %4232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 38, !dbg !184 + %4233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 39, !dbg !184 + %4234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 40, !dbg !184 + %4235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 41, !dbg !184 + %4236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 42, !dbg !184 + %4237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 43, !dbg !184 + %4238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 44, !dbg !184 + %4239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 45, !dbg !184 + %4240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 46, !dbg !184 + %4241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 47, !dbg !184 + %4242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 48, !dbg !184 + %4243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 49, !dbg !184 + %4244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 50, !dbg !184 + %4245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 51, !dbg !184 + %4246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 52, !dbg !184 + %4247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 53, !dbg !184 + %4248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 54, !dbg !184 + %4249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 55, !dbg !184 + %4250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 56, !dbg !184 + %4251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 57, !dbg !184 + %4252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 58, !dbg !184 + %4253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 59, !dbg !184 + %4254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 60, !dbg !184 + %4255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 61, !dbg !184 + %4256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 62, !dbg !184 + %4257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4193, 63, !dbg !184 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !184 + %4258 = insertelement <16 x i32> poison, i32 %2710, i64 0, !dbg !172 + %4259 = shufflevector <16 x i32> %4258, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !172 + %4260 = add <16 x i32> %4259, %2714, !dbg !172 + %4261 = add nuw nsw i32 %2713, 1, !dbg !167 + %4262 = lshr i32 %4261, 1, !dbg !185 + %4263 = zext nneg i32 %4262 to i64, !dbg !186 + %4264 = getelementptr i32, ptr addrspace(1) %2496, i64 %4263, !dbg !186 + %4265 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !187 + %4266 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %4264, i64 %4265, i1 %2716) #3, !dbg !187 + %4267 = add nuw nsw i32 %4262, 1, !dbg !188 + %4268 = icmp slt i32 %4267, %2500, !dbg !189 + %4269 = getelementptr i8, ptr addrspace(1) %4264, i64 4, !dbg !190 + %4270 = and i1 %2716, %4268, !dbg !167 + %4271 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !191 + %4272 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %4269, i64 %4271, i1 %4270) #3, !dbg !191 + %4273 = and i32 %2713, 1, !dbg !192 + %4274 = sub i32 %4272, %4266, !dbg !193 + %4275 = shl i32 %4274, 7, !dbg !194 + %4276 = add i32 %4275, -64, !dbg !195 + %4277 = xor i32 %4273, 1, !dbg !196 + %4278 = mul nuw nsw i32 %4276, %4277, !dbg !196 + %4279 = shl nuw nsw i32 %4273, 6, !dbg !197 + %4280 = add i32 %4278, %4279, !dbg !198 + %4281 = shl i32 %4280, 10, !dbg !199 + %4282 = sext i32 %4281 to i64, !dbg !170 + %4283 = getelementptr bfloat, ptr addrspace(1) %.pn10771643, i64 %4282, !dbg !170 + %4284 = getelementptr bfloat, ptr addrspace(1) %.pn10611644, i64 %4282, !dbg !170 + %4285 = getelementptr bfloat, ptr addrspace(1) %.pn10451645, i64 %4282, !dbg !170 + %4286 = getelementptr bfloat, ptr addrspace(1) %.pn10291646, i64 %4282, !dbg !170 + %4287 = getelementptr bfloat, ptr addrspace(1) %.pn11491651, i64 %4282, !dbg !171 + %4288 = getelementptr bfloat, ptr addrspace(1) %.pn11331652, i64 %4282, !dbg !171 + %4289 = getelementptr bfloat, ptr addrspace(1) %.pn11171653, i64 %4282, !dbg !171 + %4290 = getelementptr bfloat, ptr addrspace(1) %.pn11011654, i64 %4282, !dbg !171 + %4291 = add i32 %4280, %.pn10851647, !dbg !172 + %4292 = add i32 %4280, %.pn10831648, !dbg !172 + %4293 = add i32 %4280, %.pn10811649, !dbg !172 + %4294 = add i32 %4280, %.pn10791650, !dbg !172 + %4295 = add i32 %2712, 1, !dbg !167 + %4296 = icmp sgt i32 %4295, 2, !dbg !167 + %4297 = select i1 %4296, i32 0, i32 %4295, !dbg !167 + %4298 = icmp slt i32 %4291, %18, !dbg !168 + %4299 = icmp slt i32 %4292, %18, !dbg !168 + %4300 = icmp slt i32 %4293, %18, !dbg !168 + %4301 = icmp slt i32 %4294, %18, !dbg !168 + %4302 = shl i32 %4297, 13, !dbg !169 + %4303 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %4302, !dbg !169 + %4304 = and i1 %2715, %4298, !dbg !167 + %4305 = and i1 %2715, %4299, !dbg !167 + %4306 = and i1 %2715, %4300, !dbg !167 + %4307 = and i1 %2715, %4301, !dbg !167 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !169 + %4308 = getelementptr inbounds nuw i8, ptr addrspace(3) %4303, i32 %437, !dbg !169 + %4309 = select i1 %4304, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4308, ptr addrspace(1) %4283, i32 %4309) #3, !dbg !169 + %4310 = getelementptr inbounds nuw i8, ptr addrspace(3) %4303, i32 %440, !dbg !169 + %4311 = select i1 %4305, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4310, ptr addrspace(1) %4284, i32 %4311) #3, !dbg !169 + %4312 = getelementptr inbounds nuw i8, ptr addrspace(3) %4303, i32 %443, !dbg !169 + %4313 = select i1 %4306, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4312, ptr addrspace(1) %4285, i32 %4313) #3, !dbg !169 + %4314 = getelementptr inbounds nuw i8, ptr addrspace(3) %4303, i32 %446, !dbg !169 + %4315 = select i1 %4307, i32 16, i32 0, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4314, ptr addrspace(1) %4286, i32 %4315) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + %4316 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4302, !dbg !169 + %4317 = getelementptr inbounds nuw i8, ptr addrspace(3) %4316, i32 %437, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4317, ptr addrspace(1) %4287, i32 %4309) #3, !dbg !169 + %4318 = getelementptr inbounds nuw i8, ptr addrspace(3) %4316, i32 %440, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4318, ptr addrspace(1) %4288, i32 %4311) #3, !dbg !169 + %4319 = getelementptr inbounds nuw i8, ptr addrspace(3) %4316, i32 %443, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4319, ptr addrspace(1) %4289, i32 %4313) #3, !dbg !169 + %4320 = getelementptr inbounds nuw i8, ptr addrspace(3) %4316, i32 %446, !dbg !169 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4320, ptr addrspace(1) %4290, i32 %4315) #3, !dbg !169 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !169 + %exitcond2264.not = icmp eq i32 %4261, %smax2263, !dbg !167 + br i1 %exitcond2264.not, label %._crit_edge1673, label %2709, !dbg !167 + +._crit_edge1673: ; preds = %__nv_exp2f.exit1513, %._crit_edge + %4321 = phi float [ %2569, %._crit_edge ], [ %4194, %__nv_exp2f.exit1513 ], !dbg !92 + %4322 = phi float [ %2570, %._crit_edge ], [ %4195, %__nv_exp2f.exit1513 ], !dbg !92 + %4323 = phi float [ %2571, %._crit_edge ], [ %4196, %__nv_exp2f.exit1513 ], !dbg !92 + %4324 = phi float [ %2572, %._crit_edge ], [ %4197, %__nv_exp2f.exit1513 ], !dbg !92 + %4325 = phi float [ %2573, %._crit_edge ], [ %4198, %__nv_exp2f.exit1513 ], !dbg !92 + %4326 = phi float [ %2574, %._crit_edge ], [ %4199, %__nv_exp2f.exit1513 ], !dbg !92 + %4327 = phi float [ %2575, %._crit_edge ], [ %4200, %__nv_exp2f.exit1513 ], !dbg !92 + %4328 = phi float [ %2576, %._crit_edge ], [ %4201, %__nv_exp2f.exit1513 ], !dbg !92 + %4329 = phi float [ %2577, %._crit_edge ], [ %4202, %__nv_exp2f.exit1513 ], !dbg !92 + %4330 = phi float [ %2578, %._crit_edge ], [ %4203, %__nv_exp2f.exit1513 ], !dbg !92 + %4331 = phi float [ %2579, %._crit_edge ], [ %4204, %__nv_exp2f.exit1513 ], !dbg !92 + %4332 = phi float [ %2580, %._crit_edge ], [ %4205, %__nv_exp2f.exit1513 ], !dbg !92 + %4333 = phi float [ %2581, %._crit_edge ], [ %4206, %__nv_exp2f.exit1513 ], !dbg !92 + %4334 = phi float [ %2582, %._crit_edge ], [ %4207, %__nv_exp2f.exit1513 ], !dbg !92 + %4335 = phi float [ %2583, %._crit_edge ], [ %4208, %__nv_exp2f.exit1513 ], !dbg !92 + %4336 = phi float [ %2584, %._crit_edge ], [ %4209, %__nv_exp2f.exit1513 ], !dbg !92 + %4337 = phi float [ %2585, %._crit_edge ], [ %4210, %__nv_exp2f.exit1513 ], !dbg !92 + %4338 = phi float [ %2586, %._crit_edge ], [ %4211, %__nv_exp2f.exit1513 ], !dbg !92 + %4339 = phi float [ %2587, %._crit_edge ], [ %4212, %__nv_exp2f.exit1513 ], !dbg !92 + %4340 = phi float [ %2588, %._crit_edge ], [ %4213, %__nv_exp2f.exit1513 ], !dbg !92 + %4341 = phi float [ %2589, %._crit_edge ], [ %4214, %__nv_exp2f.exit1513 ], !dbg !92 + %4342 = phi float [ %2590, %._crit_edge ], [ %4215, %__nv_exp2f.exit1513 ], !dbg !92 + %4343 = phi float [ %2591, %._crit_edge ], [ %4216, %__nv_exp2f.exit1513 ], !dbg !92 + %4344 = phi float [ %2592, %._crit_edge ], [ %4217, %__nv_exp2f.exit1513 ], !dbg !92 + %4345 = phi float [ %2593, %._crit_edge ], [ %4218, %__nv_exp2f.exit1513 ], !dbg !92 + %4346 = phi float [ %2594, %._crit_edge ], [ %4219, %__nv_exp2f.exit1513 ], !dbg !92 + %4347 = phi float [ %2595, %._crit_edge ], [ %4220, %__nv_exp2f.exit1513 ], !dbg !92 + %4348 = phi float [ %2596, %._crit_edge ], [ %4221, %__nv_exp2f.exit1513 ], !dbg !92 + %4349 = phi float [ %2597, %._crit_edge ], [ %4222, %__nv_exp2f.exit1513 ], !dbg !92 + %4350 = phi float [ %2598, %._crit_edge ], [ %4223, %__nv_exp2f.exit1513 ], !dbg !92 + %4351 = phi float [ %2599, %._crit_edge ], [ %4224, %__nv_exp2f.exit1513 ], !dbg !92 + %4352 = phi float [ %2600, %._crit_edge ], [ %4225, %__nv_exp2f.exit1513 ], !dbg !92 + %4353 = phi float [ %2601, %._crit_edge ], [ %4226, %__nv_exp2f.exit1513 ], !dbg !92 + %4354 = phi float [ %2602, %._crit_edge ], [ %4227, %__nv_exp2f.exit1513 ], !dbg !92 + %4355 = phi float [ %2603, %._crit_edge ], [ %4228, %__nv_exp2f.exit1513 ], !dbg !92 + %4356 = phi float [ %2604, %._crit_edge ], [ %4229, %__nv_exp2f.exit1513 ], !dbg !92 + %4357 = phi float [ %2605, %._crit_edge ], [ %4230, %__nv_exp2f.exit1513 ], !dbg !92 + %4358 = phi float [ %2606, %._crit_edge ], [ %4231, %__nv_exp2f.exit1513 ], !dbg !92 + %4359 = phi float [ %2607, %._crit_edge ], [ %4232, %__nv_exp2f.exit1513 ], !dbg !92 + %4360 = phi float [ %2608, %._crit_edge ], [ %4233, %__nv_exp2f.exit1513 ], !dbg !92 + %4361 = phi float [ %2609, %._crit_edge ], [ %4234, %__nv_exp2f.exit1513 ], !dbg !92 + %4362 = phi float [ %2610, %._crit_edge ], [ %4235, %__nv_exp2f.exit1513 ], !dbg !92 + %4363 = phi float [ %2611, %._crit_edge ], [ %4236, %__nv_exp2f.exit1513 ], !dbg !92 + %4364 = phi float [ %2612, %._crit_edge ], [ %4237, %__nv_exp2f.exit1513 ], !dbg !92 + %4365 = phi float [ %2613, %._crit_edge ], [ %4238, %__nv_exp2f.exit1513 ], !dbg !92 + %4366 = phi float [ %2614, %._crit_edge ], [ %4239, %__nv_exp2f.exit1513 ], !dbg !92 + %4367 = phi float [ %2615, %._crit_edge ], [ %4240, %__nv_exp2f.exit1513 ], !dbg !92 + %4368 = phi float [ %2616, %._crit_edge ], [ %4241, %__nv_exp2f.exit1513 ], !dbg !92 + %4369 = phi float [ %2617, %._crit_edge ], [ %4242, %__nv_exp2f.exit1513 ], !dbg !92 + %4370 = phi float [ %2618, %._crit_edge ], [ %4243, %__nv_exp2f.exit1513 ], !dbg !92 + %4371 = phi float [ %2619, %._crit_edge ], [ %4244, %__nv_exp2f.exit1513 ], !dbg !92 + %4372 = phi float [ %2620, %._crit_edge ], [ %4245, %__nv_exp2f.exit1513 ], !dbg !92 + %4373 = phi float [ %2621, %._crit_edge ], [ %4246, %__nv_exp2f.exit1513 ], !dbg !92 + %4374 = phi float [ %2622, %._crit_edge ], [ %4247, %__nv_exp2f.exit1513 ], !dbg !92 + %4375 = phi float [ %2623, %._crit_edge ], [ %4248, %__nv_exp2f.exit1513 ], !dbg !92 + %4376 = phi float [ %2624, %._crit_edge ], [ %4249, %__nv_exp2f.exit1513 ], !dbg !92 + %4377 = phi float [ %2625, %._crit_edge ], [ %4250, %__nv_exp2f.exit1513 ], !dbg !92 + %4378 = phi float [ %2626, %._crit_edge ], [ %4251, %__nv_exp2f.exit1513 ], !dbg !92 + %4379 = phi float [ %2627, %._crit_edge ], [ %4252, %__nv_exp2f.exit1513 ], !dbg !92 + %4380 = phi float [ %2628, %._crit_edge ], [ %4253, %__nv_exp2f.exit1513 ], !dbg !92 + %4381 = phi float [ %2629, %._crit_edge ], [ %4254, %__nv_exp2f.exit1513 ], !dbg !92 + %4382 = phi float [ %2630, %._crit_edge ], [ %4255, %__nv_exp2f.exit1513 ], !dbg !92 + %4383 = phi float [ %2631, %._crit_edge ], [ %4256, %__nv_exp2f.exit1513 ], !dbg !92 + %4384 = phi float [ %2632, %._crit_edge ], [ %4257, %__nv_exp2f.exit1513 ], !dbg !92 + %4385 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %4321, float %4322, float %4323, float %4324, float %4325, float %4326, float %4327, float %4328, float %4329, float %4330, float %4331, float %4332, float %4333, float %4334, float %4335, float %4336, float %4337, float %4338, float %4339, float %4340, float %4341, float %4342, float %4343, float %4344, float %4345, float %4346, float %4347, float %4348, float %4349, float %4350, float %4351, float %4352, float %4353, float %4354, float %4355, float %4356, float %4357, float %4358, float %4359, float %4360, float %4361, float %4362, float %4363, float %4364, float %4365, float %4366, float %4367, float %4368, float %4369, float %4370, float %4371, float %4372, float %4373, float %4374, float %4375, float %4376, float %4377, float %4378, float %4379, float %4380, float %4381, float %4382, float %4383, float %4384) #3, !dbg !167 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !167 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !167 + %4386 = getelementptr bfloat, ptr addrspace(1) %90, i64 %112, !dbg !200 + %4387 = getelementptr bfloat, ptr addrspace(1) %90, i64 %114, !dbg !200 + %4388 = getelementptr bfloat, ptr addrspace(1) %90, i64 %116, !dbg !200 + %4389 = getelementptr bfloat, ptr addrspace(1) %90, i64 %118, !dbg !200 + %4390 = getelementptr bfloat, ptr addrspace(1) %90, i64 %120, !dbg !200 + %4391 = getelementptr bfloat, ptr addrspace(1) %90, i64 %122, !dbg !200 + %4392 = getelementptr bfloat, ptr addrspace(1) %90, i64 %124, !dbg !200 + %4393 = getelementptr bfloat, ptr addrspace(1) %90, i64 %126, !dbg !200 + %4394 = getelementptr bfloat, ptr addrspace(1) %4386, i64 %130, !dbg !201 + %4395 = getelementptr bfloat, ptr addrspace(1) %4387, i64 %130, !dbg !201 + %4396 = getelementptr bfloat, ptr addrspace(1) %4388, i64 %130, !dbg !201 + %4397 = getelementptr bfloat, ptr addrspace(1) %4389, i64 %130, !dbg !201 + %4398 = getelementptr bfloat, ptr addrspace(1) %4390, i64 %130, !dbg !201 + %4399 = getelementptr bfloat, ptr addrspace(1) %4391, i64 %130, !dbg !201 + %4400 = getelementptr bfloat, ptr addrspace(1) %4392, i64 %130, !dbg !201 + %4401 = getelementptr bfloat, ptr addrspace(1) %4393, i64 %130, !dbg !201 + %4402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 0, !dbg !202 + %4403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 1, !dbg !202 + %4404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 2, !dbg !202 + %4405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 3, !dbg !202 + %4406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 4, !dbg !202 + %4407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 5, !dbg !202 + %4408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 6, !dbg !202 + %4409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 7, !dbg !202 + %4410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 8, !dbg !202 + %4411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 9, !dbg !202 + %4412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 10, !dbg !202 + %4413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 11, !dbg !202 + %4414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 12, !dbg !202 + %4415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 13, !dbg !202 + %4416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 14, !dbg !202 + %4417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 15, !dbg !202 + %4418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 16, !dbg !202 + %4419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 17, !dbg !202 + %4420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 18, !dbg !202 + %4421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 19, !dbg !202 + %4422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 20, !dbg !202 + %4423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 21, !dbg !202 + %4424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 22, !dbg !202 + %4425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 23, !dbg !202 + %4426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 24, !dbg !202 + %4427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 25, !dbg !202 + %4428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 26, !dbg !202 + %4429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 27, !dbg !202 + %4430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 28, !dbg !202 + %4431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 29, !dbg !202 + %4432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 30, !dbg !202 + %4433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 31, !dbg !202 + %4434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 32, !dbg !202 + %4435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 33, !dbg !202 + %4436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 34, !dbg !202 + %4437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 35, !dbg !202 + %4438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 36, !dbg !202 + %4439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 37, !dbg !202 + %4440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 38, !dbg !202 + %4441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 39, !dbg !202 + %4442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 40, !dbg !202 + %4443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 41, !dbg !202 + %4444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 42, !dbg !202 + %4445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 43, !dbg !202 + %4446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 44, !dbg !202 + %4447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 45, !dbg !202 + %4448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 46, !dbg !202 + %4449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 47, !dbg !202 + %4450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 48, !dbg !202 + %4451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 49, !dbg !202 + %4452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 50, !dbg !202 + %4453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 51, !dbg !202 + %4454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 52, !dbg !202 + %4455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 53, !dbg !202 + %4456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 54, !dbg !202 + %4457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 55, !dbg !202 + %4458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 56, !dbg !202 + %4459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 57, !dbg !202 + %4460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 58, !dbg !202 + %4461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 59, !dbg !202 + %4462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 60, !dbg !202 + %4463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 61, !dbg !202 + %4464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 62, !dbg !202 + %4465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4385, 63, !dbg !202 + %4466 = insertelement <2 x float> poison, float %4402, i64 0, !dbg !202 + %4467 = insertelement <2 x float> %4466, float %4403, i64 1, !dbg !202 + %4468 = fmul <2 x float> %4467, splat (float 0x3FB6A09E60000000), !dbg !202 + %4469 = fptrunc <2 x float> %4468 to <2 x bfloat>, !dbg !203 + %4470 = insertelement <2 x float> poison, float %4404, i64 0, !dbg !202 + %4471 = insertelement <2 x float> %4470, float %4405, i64 1, !dbg !202 + %4472 = fmul <2 x float> %4471, splat (float 0x3FB6A09E60000000), !dbg !202 + %4473 = fptrunc <2 x float> %4472 to <2 x bfloat>, !dbg !203 + %4474 = insertelement <2 x float> poison, float %4406, i64 0, !dbg !202 + %4475 = insertelement <2 x float> %4474, float %4407, i64 1, !dbg !202 + %4476 = fmul <2 x float> %4475, splat (float 0x3FB6A09E60000000), !dbg !202 + %4477 = fptrunc <2 x float> %4476 to <2 x bfloat>, !dbg !203 + %4478 = insertelement <2 x float> poison, float %4408, i64 0, !dbg !202 + %4479 = insertelement <2 x float> %4478, float %4409, i64 1, !dbg !202 + %4480 = fmul <2 x float> %4479, splat (float 0x3FB6A09E60000000), !dbg !202 + %4481 = fptrunc <2 x float> %4480 to <2 x bfloat>, !dbg !203 + %4482 = insertelement <2 x float> poison, float %4410, i64 0, !dbg !202 + %4483 = insertelement <2 x float> %4482, float %4411, i64 1, !dbg !202 + %4484 = fmul <2 x float> %4483, splat (float 0x3FB6A09E60000000), !dbg !202 + %4485 = fptrunc <2 x float> %4484 to <2 x bfloat>, !dbg !203 + %4486 = insertelement <2 x float> poison, float %4412, i64 0, !dbg !202 + %4487 = insertelement <2 x float> %4486, float %4413, i64 1, !dbg !202 + %4488 = fmul <2 x float> %4487, splat (float 0x3FB6A09E60000000), !dbg !202 + %4489 = fptrunc <2 x float> %4488 to <2 x bfloat>, !dbg !203 + %4490 = insertelement <2 x float> poison, float %4414, i64 0, !dbg !202 + %4491 = insertelement <2 x float> %4490, float %4415, i64 1, !dbg !202 + %4492 = fmul <2 x float> %4491, splat (float 0x3FB6A09E60000000), !dbg !202 + %4493 = fptrunc <2 x float> %4492 to <2 x bfloat>, !dbg !203 + %4494 = insertelement <2 x float> poison, float %4416, i64 0, !dbg !202 + %4495 = insertelement <2 x float> %4494, float %4417, i64 1, !dbg !202 + %4496 = fmul <2 x float> %4495, splat (float 0x3FB6A09E60000000), !dbg !202 + %4497 = fptrunc <2 x float> %4496 to <2 x bfloat>, !dbg !203 + %4498 = insertelement <2 x float> poison, float %4418, i64 0, !dbg !202 + %4499 = insertelement <2 x float> %4498, float %4419, i64 1, !dbg !202 + %4500 = fmul <2 x float> %4499, splat (float 0x3FB6A09E60000000), !dbg !202 + %4501 = fptrunc <2 x float> %4500 to <2 x bfloat>, !dbg !203 + %4502 = insertelement <2 x float> poison, float %4420, i64 0, !dbg !202 + %4503 = insertelement <2 x float> %4502, float %4421, i64 1, !dbg !202 + %4504 = fmul <2 x float> %4503, splat (float 0x3FB6A09E60000000), !dbg !202 + %4505 = fptrunc <2 x float> %4504 to <2 x bfloat>, !dbg !203 + %4506 = insertelement <2 x float> poison, float %4422, i64 0, !dbg !202 + %4507 = insertelement <2 x float> %4506, float %4423, i64 1, !dbg !202 + %4508 = fmul <2 x float> %4507, splat (float 0x3FB6A09E60000000), !dbg !202 + %4509 = fptrunc <2 x float> %4508 to <2 x bfloat>, !dbg !203 + %4510 = insertelement <2 x float> poison, float %4424, i64 0, !dbg !202 + %4511 = insertelement <2 x float> %4510, float %4425, i64 1, !dbg !202 + %4512 = fmul <2 x float> %4511, splat (float 0x3FB6A09E60000000), !dbg !202 + %4513 = fptrunc <2 x float> %4512 to <2 x bfloat>, !dbg !203 + %4514 = insertelement <2 x float> poison, float %4426, i64 0, !dbg !202 + %4515 = insertelement <2 x float> %4514, float %4427, i64 1, !dbg !202 + %4516 = fmul <2 x float> %4515, splat (float 0x3FB6A09E60000000), !dbg !202 + %4517 = fptrunc <2 x float> %4516 to <2 x bfloat>, !dbg !203 + %4518 = insertelement <2 x float> poison, float %4428, i64 0, !dbg !202 + %4519 = insertelement <2 x float> %4518, float %4429, i64 1, !dbg !202 + %4520 = fmul <2 x float> %4519, splat (float 0x3FB6A09E60000000), !dbg !202 + %4521 = fptrunc <2 x float> %4520 to <2 x bfloat>, !dbg !203 + %4522 = insertelement <2 x float> poison, float %4430, i64 0, !dbg !202 + %4523 = insertelement <2 x float> %4522, float %4431, i64 1, !dbg !202 + %4524 = fmul <2 x float> %4523, splat (float 0x3FB6A09E60000000), !dbg !202 + %4525 = fptrunc <2 x float> %4524 to <2 x bfloat>, !dbg !203 + %4526 = insertelement <2 x float> poison, float %4432, i64 0, !dbg !202 + %4527 = insertelement <2 x float> %4526, float %4433, i64 1, !dbg !202 + %4528 = fmul <2 x float> %4527, splat (float 0x3FB6A09E60000000), !dbg !202 + %4529 = fptrunc <2 x float> %4528 to <2 x bfloat>, !dbg !203 + %4530 = insertelement <2 x float> poison, float %4434, i64 0, !dbg !202 + %4531 = insertelement <2 x float> %4530, float %4435, i64 1, !dbg !202 + %4532 = fmul <2 x float> %4531, splat (float 0x3FB6A09E60000000), !dbg !202 + %4533 = fptrunc <2 x float> %4532 to <2 x bfloat>, !dbg !203 + %4534 = insertelement <2 x float> poison, float %4436, i64 0, !dbg !202 + %4535 = insertelement <2 x float> %4534, float %4437, i64 1, !dbg !202 + %4536 = fmul <2 x float> %4535, splat (float 0x3FB6A09E60000000), !dbg !202 + %4537 = fptrunc <2 x float> %4536 to <2 x bfloat>, !dbg !203 + %4538 = insertelement <2 x float> poison, float %4438, i64 0, !dbg !202 + %4539 = insertelement <2 x float> %4538, float %4439, i64 1, !dbg !202 + %4540 = fmul <2 x float> %4539, splat (float 0x3FB6A09E60000000), !dbg !202 + %4541 = fptrunc <2 x float> %4540 to <2 x bfloat>, !dbg !203 + %4542 = insertelement <2 x float> poison, float %4440, i64 0, !dbg !202 + %4543 = insertelement <2 x float> %4542, float %4441, i64 1, !dbg !202 + %4544 = fmul <2 x float> %4543, splat (float 0x3FB6A09E60000000), !dbg !202 + %4545 = fptrunc <2 x float> %4544 to <2 x bfloat>, !dbg !203 + %4546 = insertelement <2 x float> poison, float %4442, i64 0, !dbg !202 + %4547 = insertelement <2 x float> %4546, float %4443, i64 1, !dbg !202 + %4548 = fmul <2 x float> %4547, splat (float 0x3FB6A09E60000000), !dbg !202 + %4549 = fptrunc <2 x float> %4548 to <2 x bfloat>, !dbg !203 + %4550 = insertelement <2 x float> poison, float %4444, i64 0, !dbg !202 + %4551 = insertelement <2 x float> %4550, float %4445, i64 1, !dbg !202 + %4552 = fmul <2 x float> %4551, splat (float 0x3FB6A09E60000000), !dbg !202 + %4553 = fptrunc <2 x float> %4552 to <2 x bfloat>, !dbg !203 + %4554 = insertelement <2 x float> poison, float %4446, i64 0, !dbg !202 + %4555 = insertelement <2 x float> %4554, float %4447, i64 1, !dbg !202 + %4556 = fmul <2 x float> %4555, splat (float 0x3FB6A09E60000000), !dbg !202 + %4557 = fptrunc <2 x float> %4556 to <2 x bfloat>, !dbg !203 + %4558 = insertelement <2 x float> poison, float %4448, i64 0, !dbg !202 + %4559 = insertelement <2 x float> %4558, float %4449, i64 1, !dbg !202 + %4560 = fmul <2 x float> %4559, splat (float 0x3FB6A09E60000000), !dbg !202 + %4561 = fptrunc <2 x float> %4560 to <2 x bfloat>, !dbg !203 + %4562 = insertelement <2 x float> poison, float %4450, i64 0, !dbg !202 + %4563 = insertelement <2 x float> %4562, float %4451, i64 1, !dbg !202 + %4564 = fmul <2 x float> %4563, splat (float 0x3FB6A09E60000000), !dbg !202 + %4565 = fptrunc <2 x float> %4564 to <2 x bfloat>, !dbg !203 + %4566 = insertelement <2 x float> poison, float %4452, i64 0, !dbg !202 + %4567 = insertelement <2 x float> %4566, float %4453, i64 1, !dbg !202 + %4568 = fmul <2 x float> %4567, splat (float 0x3FB6A09E60000000), !dbg !202 + %4569 = fptrunc <2 x float> %4568 to <2 x bfloat>, !dbg !203 + %4570 = insertelement <2 x float> poison, float %4454, i64 0, !dbg !202 + %4571 = insertelement <2 x float> %4570, float %4455, i64 1, !dbg !202 + %4572 = fmul <2 x float> %4571, splat (float 0x3FB6A09E60000000), !dbg !202 + %4573 = fptrunc <2 x float> %4572 to <2 x bfloat>, !dbg !203 + %4574 = insertelement <2 x float> poison, float %4456, i64 0, !dbg !202 + %4575 = insertelement <2 x float> %4574, float %4457, i64 1, !dbg !202 + %4576 = fmul <2 x float> %4575, splat (float 0x3FB6A09E60000000), !dbg !202 + %4577 = fptrunc <2 x float> %4576 to <2 x bfloat>, !dbg !203 + %4578 = insertelement <2 x float> poison, float %4458, i64 0, !dbg !202 + %4579 = insertelement <2 x float> %4578, float %4459, i64 1, !dbg !202 + %4580 = fmul <2 x float> %4579, splat (float 0x3FB6A09E60000000), !dbg !202 + %4581 = fptrunc <2 x float> %4580 to <2 x bfloat>, !dbg !203 + %4582 = insertelement <2 x float> poison, float %4460, i64 0, !dbg !202 + %4583 = insertelement <2 x float> %4582, float %4461, i64 1, !dbg !202 + %4584 = fmul <2 x float> %4583, splat (float 0x3FB6A09E60000000), !dbg !202 + %4585 = fptrunc <2 x float> %4584 to <2 x bfloat>, !dbg !203 + %4586 = insertelement <2 x float> poison, float %4462, i64 0, !dbg !202 + %4587 = insertelement <2 x float> %4586, float %4463, i64 1, !dbg !202 + %4588 = fmul <2 x float> %4587, splat (float 0x3FB6A09E60000000), !dbg !202 + %4589 = fptrunc <2 x float> %4588 to <2 x bfloat>, !dbg !203 + %4590 = insertelement <2 x float> poison, float %4464, i64 0, !dbg !202 + %4591 = insertelement <2 x float> %4590, float %4465, i64 1, !dbg !202 + %4592 = fmul <2 x float> %4591, splat (float 0x3FB6A09E60000000), !dbg !202 + %4593 = fptrunc <2 x float> %4592 to <2 x bfloat>, !dbg !203 + %4594 = shl nuw nsw i32 %382, 13, !dbg !203 + %4595 = shl nuw nsw i32 %50, 5, !dbg !203 + %4596 = and i32 %4595, 7264, !dbg !203 + %4597 = and i32 %50, 24, !dbg !203 + %4598 = shl nuw nsw i32 %4597, 4, !dbg !203 + %4599 = shl nuw nsw i32 %50, 2, !dbg !203 + %4600 = and i32 %4599, 16, !dbg !203 + %4601 = or disjoint i32 %4594, %4600, !dbg !203 + %4602 = or disjoint i32 %4596, %4598, !dbg !203 + %4603 = or disjoint i32 %4601, %4602, !dbg !203 + %4604 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4603, !dbg !203 + %4605 = bitcast <2 x bfloat> %4469 to i32, !dbg !203 + %4606 = bitcast <2 x bfloat> %4477 to i32, !dbg !203 + %4607 = bitcast <2 x bfloat> %4485 to i32, !dbg !203 + %4608 = bitcast <2 x bfloat> %4493 to i32, !dbg !203 + %4609 = insertelement <4 x i32> poison, i32 %4605, i64 0, !dbg !203 + %4610 = insertelement <4 x i32> %4609, i32 %4606, i64 1, !dbg !203 + %4611 = insertelement <4 x i32> %4610, i32 %4607, i64 2, !dbg !203 + %4612 = insertelement <4 x i32> %4611, i32 %4608, i64 3, !dbg !203 + store <4 x i32> %4612, ptr addrspace(3) %4604, align 16, !dbg !203 + %4613 = getelementptr inbounds nuw i8, ptr addrspace(3) %4604, i32 512, !dbg !203 + %4614 = bitcast <2 x bfloat> %4473 to i32, !dbg !203 + %4615 = bitcast <2 x bfloat> %4481 to i32, !dbg !203 + %4616 = bitcast <2 x bfloat> %4489 to i32, !dbg !203 + %4617 = bitcast <2 x bfloat> %4497 to i32, !dbg !203 + %4618 = insertelement <4 x i32> poison, i32 %4614, i64 0, !dbg !203 + %4619 = insertelement <4 x i32> %4618, i32 %4615, i64 1, !dbg !203 + %4620 = insertelement <4 x i32> %4619, i32 %4616, i64 2, !dbg !203 + %4621 = insertelement <4 x i32> %4620, i32 %4617, i64 3, !dbg !203 + store <4 x i32> %4621, ptr addrspace(3) %4613, align 16, !dbg !203 + %4622 = xor i32 %4603, 32, !dbg !203 + %4623 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4622, !dbg !203 + %4624 = bitcast <2 x bfloat> %4501 to i32, !dbg !203 + %4625 = bitcast <2 x bfloat> %4509 to i32, !dbg !203 + %4626 = bitcast <2 x bfloat> %4517 to i32, !dbg !203 + %4627 = bitcast <2 x bfloat> %4525 to i32, !dbg !203 + %4628 = insertelement <4 x i32> poison, i32 %4624, i64 0, !dbg !203 + %4629 = insertelement <4 x i32> %4628, i32 %4625, i64 1, !dbg !203 + %4630 = insertelement <4 x i32> %4629, i32 %4626, i64 2, !dbg !203 + %4631 = insertelement <4 x i32> %4630, i32 %4627, i64 3, !dbg !203 + store <4 x i32> %4631, ptr addrspace(3) %4623, align 16, !dbg !203 + %4632 = getelementptr inbounds nuw i8, ptr addrspace(3) %4623, i32 512, !dbg !203 + %4633 = bitcast <2 x bfloat> %4505 to i32, !dbg !203 + %4634 = bitcast <2 x bfloat> %4513 to i32, !dbg !203 + %4635 = bitcast <2 x bfloat> %4521 to i32, !dbg !203 + %4636 = bitcast <2 x bfloat> %4529 to i32, !dbg !203 + %4637 = insertelement <4 x i32> poison, i32 %4633, i64 0, !dbg !203 + %4638 = insertelement <4 x i32> %4637, i32 %4634, i64 1, !dbg !203 + %4639 = insertelement <4 x i32> %4638, i32 %4635, i64 2, !dbg !203 + %4640 = insertelement <4 x i32> %4639, i32 %4636, i64 3, !dbg !203 + store <4 x i32> %4640, ptr addrspace(3) %4632, align 16, !dbg !203 + %4641 = xor i32 %4603, 64, !dbg !203 + %4642 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4641, !dbg !203 + %4643 = bitcast <2 x bfloat> %4533 to i32, !dbg !203 + %4644 = bitcast <2 x bfloat> %4541 to i32, !dbg !203 + %4645 = bitcast <2 x bfloat> %4549 to i32, !dbg !203 + %4646 = bitcast <2 x bfloat> %4557 to i32, !dbg !203 + %4647 = insertelement <4 x i32> poison, i32 %4643, i64 0, !dbg !203 + %4648 = insertelement <4 x i32> %4647, i32 %4644, i64 1, !dbg !203 + %4649 = insertelement <4 x i32> %4648, i32 %4645, i64 2, !dbg !203 + %4650 = insertelement <4 x i32> %4649, i32 %4646, i64 3, !dbg !203 + store <4 x i32> %4650, ptr addrspace(3) %4642, align 16, !dbg !203 + %4651 = getelementptr inbounds nuw i8, ptr addrspace(3) %4642, i32 512, !dbg !203 + %4652 = bitcast <2 x bfloat> %4537 to i32, !dbg !203 + %4653 = bitcast <2 x bfloat> %4545 to i32, !dbg !203 + %4654 = bitcast <2 x bfloat> %4553 to i32, !dbg !203 + %4655 = bitcast <2 x bfloat> %4561 to i32, !dbg !203 + %4656 = insertelement <4 x i32> poison, i32 %4652, i64 0, !dbg !203 + %4657 = insertelement <4 x i32> %4656, i32 %4653, i64 1, !dbg !203 + %4658 = insertelement <4 x i32> %4657, i32 %4654, i64 2, !dbg !203 + %4659 = insertelement <4 x i32> %4658, i32 %4655, i64 3, !dbg !203 + store <4 x i32> %4659, ptr addrspace(3) %4651, align 16, !dbg !203 + %4660 = xor i32 %4603, 96, !dbg !203 + %4661 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4660, !dbg !203 + %4662 = bitcast <2 x bfloat> %4565 to i32, !dbg !203 + %4663 = bitcast <2 x bfloat> %4573 to i32, !dbg !203 + %4664 = bitcast <2 x bfloat> %4581 to i32, !dbg !203 + %4665 = bitcast <2 x bfloat> %4589 to i32, !dbg !203 + %4666 = insertelement <4 x i32> poison, i32 %4662, i64 0, !dbg !203 + %4667 = insertelement <4 x i32> %4666, i32 %4663, i64 1, !dbg !203 + %4668 = insertelement <4 x i32> %4667, i32 %4664, i64 2, !dbg !203 + %4669 = insertelement <4 x i32> %4668, i32 %4665, i64 3, !dbg !203 + store <4 x i32> %4669, ptr addrspace(3) %4661, align 16, !dbg !203 + %4670 = getelementptr inbounds nuw i8, ptr addrspace(3) %4661, i32 512, !dbg !203 + %4671 = bitcast <2 x bfloat> %4569 to i32, !dbg !203 + %4672 = bitcast <2 x bfloat> %4577 to i32, !dbg !203 + %4673 = bitcast <2 x bfloat> %4585 to i32, !dbg !203 + %4674 = bitcast <2 x bfloat> %4593 to i32, !dbg !203 + %4675 = insertelement <4 x i32> poison, i32 %4671, i64 0, !dbg !203 + %4676 = insertelement <4 x i32> %4675, i32 %4672, i64 1, !dbg !203 + %4677 = insertelement <4 x i32> %4676, i32 %4673, i64 2, !dbg !203 + %4678 = insertelement <4 x i32> %4677, i32 %4674, i64 3, !dbg !203 + store <4 x i32> %4678, ptr addrspace(3) %4670, align 16, !dbg !203 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !203 + %4679 = shl nuw nsw i32 %4597, 10, !dbg !203 + %4680 = shl nuw nsw i32 %382, 5, !dbg !203 + %4681 = and i32 %4599, 1008, !dbg !203 + %4682 = or disjoint i32 %4679, %4680, !dbg !203 + %4683 = xor i32 %4682, %4681, !dbg !203 + %4684 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4683, !dbg !203 + %4685 = ptrtoint ptr addrspace(3) %4684 to i32, !dbg !203 + %4686 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4685) #3, !dbg !203 + %4687 = extractvalue { i32, i32, i32, i32 } %4686, 0, !dbg !203 + %4688 = extractvalue { i32, i32, i32, i32 } %4686, 1, !dbg !203 + %4689 = extractvalue { i32, i32, i32, i32 } %4686, 2, !dbg !203 + %4690 = extractvalue { i32, i32, i32, i32 } %4686, 3, !dbg !203 + %4691 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 1024, !dbg !203 + %4692 = ptrtoint ptr addrspace(3) %4691 to i32, !dbg !203 + %4693 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4692) #3, !dbg !203 + %4694 = extractvalue { i32, i32, i32, i32 } %4693, 0, !dbg !203 + %4695 = extractvalue { i32, i32, i32, i32 } %4693, 1, !dbg !203 + %4696 = extractvalue { i32, i32, i32, i32 } %4693, 2, !dbg !203 + %4697 = extractvalue { i32, i32, i32, i32 } %4693, 3, !dbg !203 + %4698 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 2048, !dbg !203 + %4699 = ptrtoint ptr addrspace(3) %4698 to i32, !dbg !203 + %4700 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4699) #3, !dbg !203 + %4701 = extractvalue { i32, i32, i32, i32 } %4700, 0, !dbg !203 + %4702 = extractvalue { i32, i32, i32, i32 } %4700, 1, !dbg !203 + %4703 = extractvalue { i32, i32, i32, i32 } %4700, 2, !dbg !203 + %4704 = extractvalue { i32, i32, i32, i32 } %4700, 3, !dbg !203 + %4705 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 3072, !dbg !203 + %4706 = ptrtoint ptr addrspace(3) %4705 to i32, !dbg !203 + %4707 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4706) #3, !dbg !203 + %4708 = extractvalue { i32, i32, i32, i32 } %4707, 0, !dbg !203 + %4709 = extractvalue { i32, i32, i32, i32 } %4707, 1, !dbg !203 + %4710 = extractvalue { i32, i32, i32, i32 } %4707, 2, !dbg !203 + %4711 = extractvalue { i32, i32, i32, i32 } %4707, 3, !dbg !203 + %4712 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 4096, !dbg !203 + %4713 = ptrtoint ptr addrspace(3) %4712 to i32, !dbg !203 + %4714 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4713) #3, !dbg !203 + %4715 = extractvalue { i32, i32, i32, i32 } %4714, 0, !dbg !203 + %4716 = extractvalue { i32, i32, i32, i32 } %4714, 1, !dbg !203 + %4717 = extractvalue { i32, i32, i32, i32 } %4714, 2, !dbg !203 + %4718 = extractvalue { i32, i32, i32, i32 } %4714, 3, !dbg !203 + %4719 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 5120, !dbg !203 + %4720 = ptrtoint ptr addrspace(3) %4719 to i32, !dbg !203 + %4721 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4720) #3, !dbg !203 + %4722 = extractvalue { i32, i32, i32, i32 } %4721, 0, !dbg !203 + %4723 = extractvalue { i32, i32, i32, i32 } %4721, 1, !dbg !203 + %4724 = extractvalue { i32, i32, i32, i32 } %4721, 2, !dbg !203 + %4725 = extractvalue { i32, i32, i32, i32 } %4721, 3, !dbg !203 + %4726 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 6144, !dbg !203 + %4727 = ptrtoint ptr addrspace(3) %4726 to i32, !dbg !203 + %4728 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4727) #3, !dbg !203 + %4729 = extractvalue { i32, i32, i32, i32 } %4728, 0, !dbg !203 + %4730 = extractvalue { i32, i32, i32, i32 } %4728, 1, !dbg !203 + %4731 = extractvalue { i32, i32, i32, i32 } %4728, 2, !dbg !203 + %4732 = extractvalue { i32, i32, i32, i32 } %4728, 3, !dbg !203 + %4733 = getelementptr inbounds nuw i8, ptr addrspace(3) %4684, i32 7168, !dbg !203 + %4734 = ptrtoint ptr addrspace(3) %4733 to i32, !dbg !203 + %4735 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4734) #3, !dbg !203 + %4736 = extractvalue { i32, i32, i32, i32 } %4735, 0, !dbg !203 + %4737 = extractvalue { i32, i32, i32, i32 } %4735, 1, !dbg !203 + %4738 = extractvalue { i32, i32, i32, i32 } %4735, 2, !dbg !203 + %4739 = extractvalue { i32, i32, i32, i32 } %4735, 3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4687, i32 %4688, i32 %4689, i32 %4690, ptr addrspace(1) %4394, i1 %139) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4694, i32 %4695, i32 %4696, i32 %4697, ptr addrspace(1) %4395, i1 %140) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4701, i32 %4702, i32 %4703, i32 %4704, ptr addrspace(1) %4396, i1 %141) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4708, i32 %4709, i32 %4710, i32 %4711, ptr addrspace(1) %4397, i1 %142) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4715, i32 %4716, i32 %4717, i32 %4718, ptr addrspace(1) %4398, i1 %143) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4722, i32 %4723, i32 %4724, i32 %4725, ptr addrspace(1) %4399, i1 %144) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4729, i32 %4730, i32 %4731, i32 %4732, ptr addrspace(1) %4400, i1 %145) #3, !dbg !203 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4736, i32 %4737, i32 %4738, i32 %4739, ptr addrspace(1) %4401, i1 %146) #3, !dbg !203 + br label %11625, !dbg !35 + +4740: ; preds = %27 + %4741 = shl nuw nsw i32 %36, 7, !dbg !204 + %4742 = or disjoint i32 %53, %4741, !dbg !205 + %4743 = or disjoint i32 %54, %4741, !dbg !205 + %4744 = or disjoint i32 %55, %4741, !dbg !205 + %4745 = or disjoint i32 %56, %4741, !dbg !205 + %4746 = or disjoint i32 %57, %4741, !dbg !205 + %4747 = or disjoint i32 %58, %4741, !dbg !205 + %4748 = or disjoint i32 %59, %4741, !dbg !205 + %4749 = or disjoint i32 %60, %4741, !dbg !205 + %4750 = shl i32 %4742, 10, !dbg !206 + %4751 = shl i32 %4743, 10, !dbg !206 + %4752 = shl i32 %4744, 10, !dbg !206 + %4753 = shl i32 %4745, 10, !dbg !206 + %4754 = shl i32 %4746, 10, !dbg !206 + %4755 = shl i32 %4747, 10, !dbg !206 + %4756 = shl i32 %4748, 10, !dbg !206 + %4757 = shl i32 %4749, 10, !dbg !206 + %4758 = sext i32 %4750 to i64, !dbg !208 + %4759 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4758, !dbg !208 + %4760 = sext i32 %4751 to i64, !dbg !208 + %4761 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4760, !dbg !208 + %4762 = sext i32 %4752 to i64, !dbg !208 + %4763 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4762, !dbg !208 + %4764 = sext i32 %4753 to i64, !dbg !208 + %4765 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4764, !dbg !208 + %4766 = sext i32 %4754 to i64, !dbg !208 + %4767 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4766, !dbg !208 + %4768 = sext i32 %4755 to i64, !dbg !208 + %4769 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4768, !dbg !208 + %4770 = sext i32 %4756 to i64, !dbg !208 + %4771 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4770, !dbg !208 + %4772 = sext i32 %4757 to i64, !dbg !208 + %4773 = getelementptr bfloat, ptr addrspace(1) %47, i64 %4772, !dbg !208 + %4774 = shl nuw nsw i32 %50, 3, !dbg !209 + %4775 = and i32 %4774, 120, !dbg !209 + %4776 = zext nneg i32 %4775 to i64, !dbg !210 + %4777 = getelementptr bfloat, ptr addrspace(1) %4759, i64 %4776, !dbg !210 + %4778 = getelementptr bfloat, ptr addrspace(1) %4761, i64 %4776, !dbg !210 + %4779 = getelementptr bfloat, ptr addrspace(1) %4763, i64 %4776, !dbg !210 + %4780 = getelementptr bfloat, ptr addrspace(1) %4765, i64 %4776, !dbg !210 + %4781 = getelementptr bfloat, ptr addrspace(1) %4767, i64 %4776, !dbg !210 + %4782 = getelementptr bfloat, ptr addrspace(1) %4769, i64 %4776, !dbg !210 + %4783 = getelementptr bfloat, ptr addrspace(1) %4771, i64 %4776, !dbg !210 + %4784 = getelementptr bfloat, ptr addrspace(1) %4773, i64 %4776, !dbg !210 + %4785 = icmp slt i32 %4742, %18, !dbg !211 + %4786 = icmp slt i32 %4743, %18, !dbg !211 + %4787 = icmp slt i32 %4744, %18, !dbg !211 + %4788 = icmp slt i32 %4745, %18, !dbg !211 + %4789 = icmp slt i32 %4746, %18, !dbg !211 + %4790 = icmp slt i32 %4747, %18, !dbg !211 + %4791 = icmp slt i32 %4748, %18, !dbg !211 + %4792 = icmp slt i32 %4749, %18, !dbg !211 + %4793 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4777, i1 %4785) #3, !dbg !212 + %4794 = extractvalue { i32, i32, i32, i32 } %4793, 0, !dbg !212 + %4795 = extractvalue { i32, i32, i32, i32 } %4793, 1, !dbg !212 + %4796 = extractvalue { i32, i32, i32, i32 } %4793, 2, !dbg !212 + %4797 = extractvalue { i32, i32, i32, i32 } %4793, 3, !dbg !212 + %4798 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4778, i1 %4786) #3, !dbg !212 + %4799 = extractvalue { i32, i32, i32, i32 } %4798, 0, !dbg !212 + %4800 = extractvalue { i32, i32, i32, i32 } %4798, 1, !dbg !212 + %4801 = extractvalue { i32, i32, i32, i32 } %4798, 2, !dbg !212 + %4802 = extractvalue { i32, i32, i32, i32 } %4798, 3, !dbg !212 + %4803 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4779, i1 %4787) #3, !dbg !212 + %4804 = extractvalue { i32, i32, i32, i32 } %4803, 0, !dbg !212 + %4805 = extractvalue { i32, i32, i32, i32 } %4803, 1, !dbg !212 + %4806 = extractvalue { i32, i32, i32, i32 } %4803, 2, !dbg !212 + %4807 = extractvalue { i32, i32, i32, i32 } %4803, 3, !dbg !212 + %4808 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4780, i1 %4788) #3, !dbg !212 + %4809 = extractvalue { i32, i32, i32, i32 } %4808, 0, !dbg !212 + %4810 = extractvalue { i32, i32, i32, i32 } %4808, 1, !dbg !212 + %4811 = extractvalue { i32, i32, i32, i32 } %4808, 2, !dbg !212 + %4812 = extractvalue { i32, i32, i32, i32 } %4808, 3, !dbg !212 + %4813 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4781, i1 %4789) #3, !dbg !212 + %4814 = extractvalue { i32, i32, i32, i32 } %4813, 0, !dbg !212 + %4815 = extractvalue { i32, i32, i32, i32 } %4813, 1, !dbg !212 + %4816 = extractvalue { i32, i32, i32, i32 } %4813, 2, !dbg !212 + %4817 = extractvalue { i32, i32, i32, i32 } %4813, 3, !dbg !212 + %4818 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4782, i1 %4790) #3, !dbg !212 + %4819 = extractvalue { i32, i32, i32, i32 } %4818, 0, !dbg !212 + %4820 = extractvalue { i32, i32, i32, i32 } %4818, 1, !dbg !212 + %4821 = extractvalue { i32, i32, i32, i32 } %4818, 2, !dbg !212 + %4822 = extractvalue { i32, i32, i32, i32 } %4818, 3, !dbg !212 + %4823 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4783, i1 %4791) #3, !dbg !212 + %4824 = extractvalue { i32, i32, i32, i32 } %4823, 0, !dbg !212 + %4825 = extractvalue { i32, i32, i32, i32 } %4823, 1, !dbg !212 + %4826 = extractvalue { i32, i32, i32, i32 } %4823, 2, !dbg !212 + %4827 = extractvalue { i32, i32, i32, i32 } %4823, 3, !dbg !212 + %4828 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4784, i1 %4792) #3, !dbg !212 + %4829 = extractvalue { i32, i32, i32, i32 } %4828, 0, !dbg !212 + %4830 = extractvalue { i32, i32, i32, i32 } %4828, 1, !dbg !212 + %4831 = extractvalue { i32, i32, i32, i32 } %4828, 2, !dbg !212 + %4832 = extractvalue { i32, i32, i32, i32 } %4828, 3, !dbg !212 + %4833 = shl nuw nsw i32 %50, 4, !dbg !212 + %4834 = and i32 %4833, 112, !dbg !212 + %4835 = shl nuw nsw i32 %52, 3, !dbg !212 + %4836 = and i32 %50, 112, !dbg !212 + %4837 = and i32 %50, 8, !dbg !212 + %4838 = shl nuw nsw i32 %4837, 11, !dbg !212 + %4839 = or disjoint i32 %4834, %4835, !dbg !212 + %4840 = xor i32 %4839, %4836, !dbg !212 + %4841 = or disjoint i32 %4840, %4838, !dbg !212 + %4842 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4841, !dbg !212 + %4843 = insertelement <4 x i32> poison, i32 %4794, i64 0, !dbg !212 + %4844 = insertelement <4 x i32> %4843, i32 %4795, i64 1, !dbg !212 + %4845 = insertelement <4 x i32> %4844, i32 %4796, i64 2, !dbg !212 + %4846 = insertelement <4 x i32> %4845, i32 %4797, i64 3, !dbg !212 + store <4 x i32> %4846, ptr addrspace(3) %4842, align 16, !dbg !212 + %4847 = or disjoint i32 %4841, 2048, !dbg !212 + %4848 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4847, !dbg !212 + %4849 = insertelement <4 x i32> poison, i32 %4799, i64 0, !dbg !212 + %4850 = insertelement <4 x i32> %4849, i32 %4800, i64 1, !dbg !212 + %4851 = insertelement <4 x i32> %4850, i32 %4801, i64 2, !dbg !212 + %4852 = insertelement <4 x i32> %4851, i32 %4802, i64 3, !dbg !212 + store <4 x i32> %4852, ptr addrspace(3) %4848, align 16, !dbg !212 + %4853 = or disjoint i32 %4841, 4096, !dbg !212 + %4854 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4853, !dbg !212 + %4855 = insertelement <4 x i32> poison, i32 %4804, i64 0, !dbg !212 + %4856 = insertelement <4 x i32> %4855, i32 %4805, i64 1, !dbg !212 + %4857 = insertelement <4 x i32> %4856, i32 %4806, i64 2, !dbg !212 + %4858 = insertelement <4 x i32> %4857, i32 %4807, i64 3, !dbg !212 + store <4 x i32> %4858, ptr addrspace(3) %4854, align 16, !dbg !212 + %4859 = or disjoint i32 %4841, 6144, !dbg !212 + %4860 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4859, !dbg !212 + %4861 = insertelement <4 x i32> poison, i32 %4809, i64 0, !dbg !212 + %4862 = insertelement <4 x i32> %4861, i32 %4810, i64 1, !dbg !212 + %4863 = insertelement <4 x i32> %4862, i32 %4811, i64 2, !dbg !212 + %4864 = insertelement <4 x i32> %4863, i32 %4812, i64 3, !dbg !212 + store <4 x i32> %4864, ptr addrspace(3) %4860, align 16, !dbg !212 + %4865 = or disjoint i32 %4841, 8192, !dbg !212 + %4866 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4865, !dbg !212 + %4867 = insertelement <4 x i32> poison, i32 %4814, i64 0, !dbg !212 + %4868 = insertelement <4 x i32> %4867, i32 %4815, i64 1, !dbg !212 + %4869 = insertelement <4 x i32> %4868, i32 %4816, i64 2, !dbg !212 + %4870 = insertelement <4 x i32> %4869, i32 %4817, i64 3, !dbg !212 + store <4 x i32> %4870, ptr addrspace(3) %4866, align 16, !dbg !212 + %4871 = or disjoint i32 %4841, 10240, !dbg !212 + %4872 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4871, !dbg !212 + %4873 = insertelement <4 x i32> poison, i32 %4819, i64 0, !dbg !212 + %4874 = insertelement <4 x i32> %4873, i32 %4820, i64 1, !dbg !212 + %4875 = insertelement <4 x i32> %4874, i32 %4821, i64 2, !dbg !212 + %4876 = insertelement <4 x i32> %4875, i32 %4822, i64 3, !dbg !212 + store <4 x i32> %4876, ptr addrspace(3) %4872, align 16, !dbg !212 + %4877 = or disjoint i32 %4841, 12288, !dbg !212 + %4878 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4877, !dbg !212 + %4879 = insertelement <4 x i32> poison, i32 %4824, i64 0, !dbg !212 + %4880 = insertelement <4 x i32> %4879, i32 %4825, i64 1, !dbg !212 + %4881 = insertelement <4 x i32> %4880, i32 %4826, i64 2, !dbg !212 + %4882 = insertelement <4 x i32> %4881, i32 %4827, i64 3, !dbg !212 + store <4 x i32> %4882, ptr addrspace(3) %4878, align 16, !dbg !212 + %4883 = or disjoint i32 %4841, 14336, !dbg !212 + %4884 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4883, !dbg !212 + %4885 = insertelement <4 x i32> poison, i32 %4829, i64 0, !dbg !212 + %4886 = insertelement <4 x i32> %4885, i32 %4830, i64 1, !dbg !212 + %4887 = insertelement <4 x i32> %4886, i32 %4831, i64 2, !dbg !212 + %4888 = insertelement <4 x i32> %4887, i32 %4832, i64 3, !dbg !212 + store <4 x i32> %4888, ptr addrspace(3) %4884, align 16, !dbg !212 + %4889 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4758, !dbg !213 + %4890 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4760, !dbg !213 + %4891 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4762, !dbg !213 + %4892 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4764, !dbg !213 + %4893 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4766, !dbg !213 + %4894 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4768, !dbg !213 + %4895 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4770, !dbg !213 + %4896 = getelementptr bfloat, ptr addrspace(1) %48, i64 %4772, !dbg !213 + %4897 = getelementptr bfloat, ptr addrspace(1) %4889, i64 %4776, !dbg !215 + %4898 = getelementptr bfloat, ptr addrspace(1) %4890, i64 %4776, !dbg !215 + %4899 = getelementptr bfloat, ptr addrspace(1) %4891, i64 %4776, !dbg !215 + %4900 = getelementptr bfloat, ptr addrspace(1) %4892, i64 %4776, !dbg !215 + %4901 = getelementptr bfloat, ptr addrspace(1) %4893, i64 %4776, !dbg !215 + %4902 = getelementptr bfloat, ptr addrspace(1) %4894, i64 %4776, !dbg !215 + %4903 = getelementptr bfloat, ptr addrspace(1) %4895, i64 %4776, !dbg !215 + %4904 = getelementptr bfloat, ptr addrspace(1) %4896, i64 %4776, !dbg !215 + %4905 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4897, i1 %4785) #3, !dbg !216 + %4906 = extractvalue { i32, i32, i32, i32 } %4905, 0, !dbg !216 + %4907 = extractvalue { i32, i32, i32, i32 } %4905, 1, !dbg !216 + %4908 = extractvalue { i32, i32, i32, i32 } %4905, 2, !dbg !216 + %4909 = extractvalue { i32, i32, i32, i32 } %4905, 3, !dbg !216 + %4910 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4898, i1 %4786) #3, !dbg !216 + %4911 = extractvalue { i32, i32, i32, i32 } %4910, 0, !dbg !216 + %4912 = extractvalue { i32, i32, i32, i32 } %4910, 1, !dbg !216 + %4913 = extractvalue { i32, i32, i32, i32 } %4910, 2, !dbg !216 + %4914 = extractvalue { i32, i32, i32, i32 } %4910, 3, !dbg !216 + %4915 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4899, i1 %4787) #3, !dbg !216 + %4916 = extractvalue { i32, i32, i32, i32 } %4915, 0, !dbg !216 + %4917 = extractvalue { i32, i32, i32, i32 } %4915, 1, !dbg !216 + %4918 = extractvalue { i32, i32, i32, i32 } %4915, 2, !dbg !216 + %4919 = extractvalue { i32, i32, i32, i32 } %4915, 3, !dbg !216 + %4920 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4900, i1 %4788) #3, !dbg !216 + %4921 = extractvalue { i32, i32, i32, i32 } %4920, 0, !dbg !216 + %4922 = extractvalue { i32, i32, i32, i32 } %4920, 1, !dbg !216 + %4923 = extractvalue { i32, i32, i32, i32 } %4920, 2, !dbg !216 + %4924 = extractvalue { i32, i32, i32, i32 } %4920, 3, !dbg !216 + %4925 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4901, i1 %4789) #3, !dbg !216 + %4926 = extractvalue { i32, i32, i32, i32 } %4925, 0, !dbg !216 + %4927 = extractvalue { i32, i32, i32, i32 } %4925, 1, !dbg !216 + %4928 = extractvalue { i32, i32, i32, i32 } %4925, 2, !dbg !216 + %4929 = extractvalue { i32, i32, i32, i32 } %4925, 3, !dbg !216 + %4930 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4902, i1 %4790) #3, !dbg !216 + %4931 = extractvalue { i32, i32, i32, i32 } %4930, 0, !dbg !216 + %4932 = extractvalue { i32, i32, i32, i32 } %4930, 1, !dbg !216 + %4933 = extractvalue { i32, i32, i32, i32 } %4930, 2, !dbg !216 + %4934 = extractvalue { i32, i32, i32, i32 } %4930, 3, !dbg !216 + %4935 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4903, i1 %4791) #3, !dbg !216 + %4936 = extractvalue { i32, i32, i32, i32 } %4935, 0, !dbg !216 + %4937 = extractvalue { i32, i32, i32, i32 } %4935, 1, !dbg !216 + %4938 = extractvalue { i32, i32, i32, i32 } %4935, 2, !dbg !216 + %4939 = extractvalue { i32, i32, i32, i32 } %4935, 3, !dbg !216 + %4940 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4904, i1 %4792) #3, !dbg !216 + %4941 = extractvalue { i32, i32, i32, i32 } %4940, 0, !dbg !216 + %4942 = extractvalue { i32, i32, i32, i32 } %4940, 1, !dbg !216 + %4943 = extractvalue { i32, i32, i32, i32 } %4940, 2, !dbg !216 + %4944 = extractvalue { i32, i32, i32, i32 } %4940, 3, !dbg !216 + %4945 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4841, !dbg !216 + %4946 = insertelement <4 x i32> poison, i32 %4906, i64 0, !dbg !216 + %4947 = insertelement <4 x i32> %4946, i32 %4907, i64 1, !dbg !216 + %4948 = insertelement <4 x i32> %4947, i32 %4908, i64 2, !dbg !216 + %4949 = insertelement <4 x i32> %4948, i32 %4909, i64 3, !dbg !216 + store <4 x i32> %4949, ptr addrspace(3) %4945, align 16, !dbg !216 + %4950 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4847, !dbg !216 + %4951 = insertelement <4 x i32> poison, i32 %4911, i64 0, !dbg !216 + %4952 = insertelement <4 x i32> %4951, i32 %4912, i64 1, !dbg !216 + %4953 = insertelement <4 x i32> %4952, i32 %4913, i64 2, !dbg !216 + %4954 = insertelement <4 x i32> %4953, i32 %4914, i64 3, !dbg !216 + store <4 x i32> %4954, ptr addrspace(3) %4950, align 16, !dbg !216 + %4955 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4853, !dbg !216 + %4956 = insertelement <4 x i32> poison, i32 %4916, i64 0, !dbg !216 + %4957 = insertelement <4 x i32> %4956, i32 %4917, i64 1, !dbg !216 + %4958 = insertelement <4 x i32> %4957, i32 %4918, i64 2, !dbg !216 + %4959 = insertelement <4 x i32> %4958, i32 %4919, i64 3, !dbg !216 + store <4 x i32> %4959, ptr addrspace(3) %4955, align 16, !dbg !216 + %4960 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4859, !dbg !216 + %4961 = insertelement <4 x i32> poison, i32 %4921, i64 0, !dbg !216 + %4962 = insertelement <4 x i32> %4961, i32 %4922, i64 1, !dbg !216 + %4963 = insertelement <4 x i32> %4962, i32 %4923, i64 2, !dbg !216 + %4964 = insertelement <4 x i32> %4963, i32 %4924, i64 3, !dbg !216 + store <4 x i32> %4964, ptr addrspace(3) %4960, align 16, !dbg !216 + %4965 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4865, !dbg !216 + %4966 = insertelement <4 x i32> poison, i32 %4926, i64 0, !dbg !216 + %4967 = insertelement <4 x i32> %4966, i32 %4927, i64 1, !dbg !216 + %4968 = insertelement <4 x i32> %4967, i32 %4928, i64 2, !dbg !216 + %4969 = insertelement <4 x i32> %4968, i32 %4929, i64 3, !dbg !216 + store <4 x i32> %4969, ptr addrspace(3) %4965, align 16, !dbg !216 + %4970 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4871, !dbg !216 + %4971 = insertelement <4 x i32> poison, i32 %4931, i64 0, !dbg !216 + %4972 = insertelement <4 x i32> %4971, i32 %4932, i64 1, !dbg !216 + %4973 = insertelement <4 x i32> %4972, i32 %4933, i64 2, !dbg !216 + %4974 = insertelement <4 x i32> %4973, i32 %4934, i64 3, !dbg !216 + store <4 x i32> %4974, ptr addrspace(3) %4970, align 16, !dbg !216 + %4975 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4877, !dbg !216 + %4976 = insertelement <4 x i32> poison, i32 %4936, i64 0, !dbg !216 + %4977 = insertelement <4 x i32> %4976, i32 %4937, i64 1, !dbg !216 + %4978 = insertelement <4 x i32> %4977, i32 %4938, i64 2, !dbg !216 + %4979 = insertelement <4 x i32> %4978, i32 %4939, i64 3, !dbg !216 + store <4 x i32> %4979, ptr addrspace(3) %4975, align 16, !dbg !216 + %4980 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4883, !dbg !216 + %4981 = insertelement <4 x i32> poison, i32 %4941, i64 0, !dbg !216 + %4982 = insertelement <4 x i32> %4981, i32 %4942, i64 1, !dbg !216 + %4983 = insertelement <4 x i32> %4982, i32 %4943, i64 2, !dbg !216 + %4984 = insertelement <4 x i32> %4983, i32 %4944, i64 3, !dbg !216 + store <4 x i32> %4984, ptr addrspace(3) %4980, align 16, !dbg !216 + %4985 = shl nuw nsw i32 %40, 2, !dbg !217 + %4986 = mul i32 %28, %39, !dbg !218 + %4987 = mul i32 %34, %39, !dbg !219 + %4988 = shl nuw nsw i32 %39, 5, !dbg !220 + %4989 = mul i32 %23, %36, !dbg !221 + %4990 = sext i32 %4989 to i64, !dbg !222 + %4991 = getelementptr i32, ptr addrspace(1) %11, i64 %4990, !dbg !222 + %4992 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4991, i1 true) #3, !dbg !223 + %4993 = shl i32 %4992, 7, !dbg !224 + %4994 = zext nneg i32 %36 to i64, !dbg !225 + %4995 = getelementptr i32, ptr addrspace(1) %10, i64 %4994, !dbg !225 + %4996 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4995, i1 true) #3, !dbg !226 + %4997 = and i32 %50, 3, !dbg !227 + %4998 = shl nuw nsw i32 %4997, 1, !dbg !227 + %4999 = or disjoint i32 %4998, 1, !dbg !227 + %5000 = or disjoint i32 %4998, 8, !dbg !227 + %5001 = or disjoint i32 %4998, 9, !dbg !227 + %5002 = or disjoint i32 %4993, %53, !dbg !228 + %5003 = or disjoint i32 %4993, %54, !dbg !228 + %5004 = or disjoint i32 %4993, %55, !dbg !228 + %5005 = or disjoint i32 %4993, %56, !dbg !228 + %5006 = shl i32 %5002, 12, !dbg !229 + %5007 = shl i32 %5003, 12, !dbg !229 + %5008 = shl i32 %5004, 12, !dbg !229 + %5009 = shl i32 %5005, 12, !dbg !229 + %5010 = shl i32 %5002, 7, !dbg !231 + %5011 = shl i32 %5003, 7, !dbg !231 + %5012 = shl i32 %5004, 7, !dbg !231 + %5013 = shl i32 %5005, 7, !dbg !231 + %5014 = shl i32 %4996, 1, !dbg !232 + %5015 = add i32 %17, 63, !dbg !233 + %5016 = sdiv i32 %5015, 64, !dbg !234 + %5017 = tail call i32 @llvm.smax.i32(i32 %5016, i32 1), !dbg !235 + %5018 = tail call i32 @llvm.smin.i32(i32 %5014, i32 %5017), !dbg !236 + %5019 = insertelement <2 x i32> poison, i32 %66, i64 0, !dbg !205 + %5020 = insertelement <2 x i32> %5019, i32 %65, i64 1, !dbg !205 + %5021 = insertelement <2 x i32> poison, i32 %4741, i64 0, !dbg !205 + %5022 = shufflevector <2 x i32> %5021, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !205 + %5023 = or disjoint <2 x i32> %5020, %5022, !dbg !205 + %5024 = insertelement <4 x i32> poison, i32 %4998, i64 0, !dbg !227 + %5025 = shufflevector <4 x i32> %5024, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !227 + %5026 = or disjoint <4 x i32> %5025, , !dbg !227 + %5027 = insertelement <8 x i32> poison, i32 %4998, i64 0, !dbg !227 + %5028 = shufflevector <8 x i32> %5027, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !227 + %5029 = or disjoint <8 x i32> %5028, , !dbg !227 + %5030 = insertelement <16 x i32> poison, i32 %4993, i64 0, !dbg !228 + %5031 = shufflevector <16 x i32> %5030, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !228 + %5032 = insertelement <16 x i32> poison, i32 %5001, i64 12, !dbg !228 + %5033 = insertelement <16 x i32> %5032, i32 %5000, i64 13, !dbg !228 + %5034 = insertelement <16 x i32> %5033, i32 %4999, i64 14, !dbg !228 + %5035 = insertelement <16 x i32> %5034, i32 %4998, i64 15, !dbg !228 + %5036 = shufflevector <8 x i32> %5029, <8 x i32> poison, <16 x i32> , !dbg !228 + %5037 = shufflevector <16 x i32> %5036, <16 x i32> %5035, <16 x i32> , !dbg !228 + %5038 = shufflevector <4 x i32> %5026, <4 x i32> poison, <16 x i32> , !dbg !228 + %5039 = shufflevector <16 x i32> %5037, <16 x i32> %5038, <16 x i32> , !dbg !228 + %5040 = or disjoint <16 x i32> %5031, %5039, !dbg !228 + %5041 = insertelement <2 x i32> poison, i32 %18, i64 0, !dbg !237 + %5042 = shufflevector <2 x i32> %5041, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !237 + %5043 = srem <2 x i32> %5023, %5042, !dbg !237 + %5044 = lshr <2 x i32> %5043, splat (i32 4), !dbg !238 + %5045 = shufflevector <2 x i32> %5044, <2 x i32> poison, <32 x i32> , !dbg !238 + %5046 = getelementptr i32, ptr addrspace(1) %15, i64 %4990, !dbg !239 + %5047 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %5046, i1 true) #3, !dbg !240 + %5048 = shl i32 %5047, 7, !dbg !241 + %5049 = getelementptr i32, ptr addrspace(1) %14, i64 %4994, !dbg !242 + %5050 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %5049, i1 true) #3, !dbg !243 + %5051 = or disjoint i32 %5048, %4998, !dbg !244 + %5052 = or disjoint i32 %5048, %4999, !dbg !244 + %5053 = or disjoint i32 %5048, %5000, !dbg !244 + %5054 = or disjoint i32 %5048, %5001, !dbg !244 + %5055 = extractelement <4 x i32> %5026, i64 3, !dbg !244 + %5056 = or disjoint i32 %5048, %5055, !dbg !244 + %5057 = extractelement <4 x i32> %5026, i64 2, !dbg !244 + %5058 = or disjoint i32 %5048, %5057, !dbg !244 + %5059 = extractelement <4 x i32> %5026, i64 1, !dbg !244 + %5060 = or disjoint i32 %5048, %5059, !dbg !244 + %5061 = extractelement <4 x i32> %5026, i64 0, !dbg !244 + %5062 = or disjoint i32 %5048, %5061, !dbg !244 + %5063 = extractelement <8 x i32> %5029, i64 7, !dbg !244 + %5064 = or disjoint i32 %5048, %5063, !dbg !244 + %5065 = extractelement <8 x i32> %5029, i64 6, !dbg !244 + %5066 = or disjoint i32 %5048, %5065, !dbg !244 + %5067 = extractelement <8 x i32> %5029, i64 5, !dbg !244 + %5068 = or disjoint i32 %5048, %5067, !dbg !244 + %5069 = extractelement <8 x i32> %5029, i64 4, !dbg !244 + %5070 = or disjoint i32 %5048, %5069, !dbg !244 + %5071 = extractelement <8 x i32> %5029, i64 3, !dbg !244 + %5072 = or disjoint i32 %5048, %5071, !dbg !244 + %5073 = extractelement <8 x i32> %5029, i64 2, !dbg !244 + %5074 = or disjoint i32 %5048, %5073, !dbg !244 + %5075 = extractelement <8 x i32> %5029, i64 1, !dbg !244 + %5076 = or disjoint i32 %5048, %5075, !dbg !244 + %5077 = extractelement <8 x i32> %5029, i64 0, !dbg !244 + %5078 = or disjoint i32 %5048, %5077, !dbg !244 + %5079 = or disjoint i32 %5048, %53, !dbg !244 + %5080 = or disjoint i32 %5048, %54, !dbg !244 + %5081 = or disjoint i32 %5048, %55, !dbg !244 + %5082 = or disjoint i32 %5048, %56, !dbg !244 + %5083 = shl i32 %5079, 12, !dbg !245 + %5084 = shl i32 %5080, 12, !dbg !245 + %5085 = shl i32 %5081, 12, !dbg !245 + %5086 = shl i32 %5082, 12, !dbg !245 + %5087 = shl i32 %5079, 7, !dbg !247 + %5088 = shl i32 %5080, 7, !dbg !247 + %5089 = shl i32 %5081, 7, !dbg !247 + %5090 = shl i32 %5082, 7, !dbg !247 + %5091 = shl i32 %5050, 1, !dbg !248 + %5092 = tail call i32 @llvm.smin.i32(i32 %5091, i32 %5017), !dbg !249 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !250 + %5093 = sext i32 %5006 to i64 + %5094 = sext i32 %5007 to i64 + %5095 = sext i32 %5008 to i64 + %5096 = sext i32 %5009 to i64 + %5097 = sext i32 %5010 to i64 + %5098 = sext i32 %5011 to i64 + %5099 = sext i32 %5012 to i64 + %5100 = sext i32 %5013 to i64 + %5101 = icmp sgt i32 %5014, 0 + %5102 = icmp slt i32 %5002, %17 + %5103 = icmp slt i32 %5003, %17 + %5104 = icmp slt i32 %5004, %17 + %5105 = icmp slt i32 %5005, %17 + %5106 = and i1 %5101, %5102 + %5107 = and i1 %5101, %5103 + %5108 = and i1 %5101, %5104 + %5109 = and i1 %5101, %5105 + %5110 = shl nuw nsw i32 %4837, 10 + %5111 = or disjoint i32 %4840, %5110 + %5112 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5111 + %5113 = select i1 %5106, i32 16, i32 0 + %5114 = or disjoint i32 %5111, 2048 + %5115 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5114 + %5116 = select i1 %5107, i32 16, i32 0 + %5117 = or disjoint i32 %5111, 4096 + %5118 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5117 + %5119 = select i1 %5108, i32 16, i32 0 + %5120 = or disjoint i32 %5111, 6144 + %5121 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5120 + %5122 = select i1 %5109, i32 16, i32 0 + %5123 = extractelement <16 x i32> %5040, i64 15 + %5124 = icmp slt i32 %5123, %17 + %5125 = extractelement <16 x i32> %5040, i64 14 + %5126 = icmp slt i32 %5125, %17 + %5127 = extractelement <16 x i32> %5040, i64 13 + %5128 = icmp slt i32 %5127, %17 + %5129 = extractelement <16 x i32> %5040, i64 12 + %5130 = icmp slt i32 %5129, %17 + %5131 = extractelement <16 x i32> %5040, i64 11 + %5132 = icmp slt i32 %5131, %17 + %5133 = extractelement <16 x i32> %5040, i64 10 + %5134 = icmp slt i32 %5133, %17 + %5135 = extractelement <16 x i32> %5040, i64 9 + %5136 = icmp slt i32 %5135, %17 + %5137 = extractelement <16 x i32> %5040, i64 8 + %5138 = icmp slt i32 %5137, %17 + %5139 = extractelement <16 x i32> %5040, i64 7 + %5140 = icmp slt i32 %5139, %17 + %5141 = extractelement <16 x i32> %5040, i64 6 + %5142 = icmp slt i32 %5141, %17 + %5143 = extractelement <16 x i32> %5040, i64 5 + %5144 = icmp slt i32 %5143, %17 + %5145 = extractelement <16 x i32> %5040, i64 4 + %5146 = icmp slt i32 %5145, %17 + %5147 = extractelement <16 x i32> %5040, i64 3 + %5148 = icmp slt i32 %5147, %17 + %5149 = extractelement <16 x i32> %5040, i64 2 + %5150 = icmp slt i32 %5149, %17 + %5151 = extractelement <16 x i32> %5040, i64 1 + %5152 = icmp slt i32 %5151, %17 + %5153 = extractelement <16 x i32> %5040, i64 0 + %5154 = icmp slt i32 %5153, %17 + %5155 = sext i32 %5123 to i64 + %5156 = sext i32 %5125 to i64 + %5157 = sext i32 %5127 to i64 + %5158 = sext i32 %5129 to i64 + %5159 = sext i32 %5131 to i64 + %5160 = sext i32 %5133 to i64 + %5161 = sext i32 %5135 to i64 + %5162 = sext i32 %5137 to i64 + %5163 = sext i32 %5139 to i64 + %5164 = sext i32 %5141 to i64 + %5165 = sext i32 %5143 to i64 + %5166 = sext i32 %5145 to i64 + %5167 = sext i32 %5147 to i64 + %5168 = sext i32 %5149 to i64 + %5169 = sext i32 %5151 to i64 + %5170 = sext i32 %5153 to i64 + %5171 = and i1 %5101, %5124 + %5172 = and i1 %5101, %5126 + %5173 = and i1 %5101, %5128 + %5174 = and i1 %5101, %5130 + %5175 = and i1 %5101, %5132 + %5176 = and i1 %5101, %5134 + %5177 = and i1 %5101, %5136 + %5178 = and i1 %5101, %5138 + %5179 = and i1 %5101, %5140 + %5180 = and i1 %5101, %5142 + %5181 = and i1 %5101, %5144 + %5182 = and i1 %5101, %5146 + %5183 = and i1 %5101, %5148 + %5184 = and i1 %5101, %5150 + %5185 = and i1 %5101, %5152 + %5186 = and i1 %5101, %5154 + %5187 = and i32 %50, 252 + %5188 = icmp eq i32 %5187, 0 + %5189 = shl nuw nsw i32 %4997, 3 + %5190 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5189 + %5191 = select i1 %5171, i32 4, i32 0 + %5192 = or disjoint i32 %5189, 4 + %5193 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5192 + %5194 = select i1 %5172, i32 4, i32 0 + %5195 = or disjoint i32 %5189, 32 + %5196 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5195 + %5197 = select i1 %5173, i32 4, i32 0 + %5198 = or disjoint i32 %5189, 36 + %5199 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5198 + %5200 = select i1 %5174, i32 4, i32 0 + %5201 = or disjoint i32 %5189, 64 + %5202 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5201 + %5203 = select i1 %5175, i32 4, i32 0 + %5204 = or disjoint i32 %5189, 68 + %5205 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5204 + %5206 = select i1 %5176, i32 4, i32 0 + %5207 = or disjoint i32 %5189, 96 + %5208 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5207 + %5209 = select i1 %5177, i32 4, i32 0 + %5210 = or disjoint i32 %5189, 100 + %5211 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5210 + %5212 = select i1 %5178, i32 4, i32 0 + %5213 = or disjoint i32 %5189, 128 + %5214 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5213 + %5215 = select i1 %5179, i32 4, i32 0 + %5216 = or disjoint i32 %5189, 132 + %5217 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5216 + %5218 = select i1 %5180, i32 4, i32 0 + %5219 = or disjoint i32 %5189, 160 + %5220 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5219 + %5221 = select i1 %5181, i32 4, i32 0 + %5222 = or disjoint i32 %5189, 164 + %5223 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5222 + %5224 = select i1 %5182, i32 4, i32 0 + %5225 = or disjoint i32 %5189, 192 + %5226 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5225 + %5227 = select i1 %5183, i32 4, i32 0 + %5228 = or disjoint i32 %5189, 196 + %5229 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5228 + %5230 = select i1 %5184, i32 4, i32 0 + %5231 = or disjoint i32 %5189, 224 + %5232 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5231 + %5233 = select i1 %5185, i32 4, i32 0 + %5234 = or disjoint i32 %5189, 228 + %5235 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5234 + %5236 = select i1 %5186, i32 4, i32 0 + %5237 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5111 + %5238 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5114 + %5239 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5117 + %5240 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5120 + %5241 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5189 + %5242 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5192 + %5243 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5195 + %5244 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5198 + %5245 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5201 + %5246 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5204 + %5247 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5207 + %5248 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5210 + %5249 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5213 + %5250 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5216 + %5251 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5219 + %5252 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5222 + %5253 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5225 + %5254 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5228 + %5255 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5231 + %5256 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5234 + %5257 = icmp sgt i32 %5018, 1 + %5258 = or disjoint i32 %5123, 64 + %5259 = or disjoint i32 %5125, 64 + %5260 = or disjoint i32 %5127, 64 + %5261 = or disjoint i32 %5129, 64 + %5262 = or disjoint i32 %5131, 64 + %5263 = or disjoint i32 %5133, 64 + %5264 = or disjoint i32 %5135, 64 + %5265 = or disjoint i32 %5137, 64 + %5266 = or disjoint i32 %5139, 64 + %5267 = or disjoint i32 %5141, 64 + %5268 = or disjoint i32 %5143, 64 + %5269 = or disjoint i32 %5145, 64 + %5270 = or disjoint i32 %5147, 64 + %5271 = or disjoint i32 %5149, 64 + %5272 = or disjoint i32 %5151, 64 + %5273 = or disjoint i32 %5153, 64 + %5274 = or disjoint i32 %5002, 64 + %5275 = or disjoint i32 %5003, 64 + %5276 = or disjoint i32 %5004, 64 + %5277 = or disjoint i32 %5005, 64 + %5278 = icmp slt i32 %5274, %17 + %5279 = icmp slt i32 %5275, %17 + %5280 = icmp slt i32 %5276, %17 + %5281 = icmp slt i32 %5277, %17 + %5282 = and i1 %5257, %5278 + %5283 = and i1 %5257, %5279 + %5284 = and i1 %5257, %5280 + %5285 = and i1 %5257, %5281 + %5286 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5111 + %5287 = select i1 %5282, i32 16, i32 0 + %5288 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5114 + %5289 = select i1 %5283, i32 16, i32 0 + %5290 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5117 + %5291 = select i1 %5284, i32 16, i32 0 + %5292 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5120 + %5293 = select i1 %5285, i32 16, i32 0 + %5294 = icmp slt i32 %5258, %17 + %5295 = icmp slt i32 %5259, %17 + %5296 = icmp slt i32 %5260, %17 + %5297 = icmp slt i32 %5261, %17 + %5298 = icmp slt i32 %5262, %17 + %5299 = icmp slt i32 %5263, %17 + %5300 = icmp slt i32 %5264, %17 + %5301 = icmp slt i32 %5265, %17 + %5302 = icmp slt i32 %5266, %17 + %5303 = icmp slt i32 %5267, %17 + %5304 = icmp slt i32 %5268, %17 + %5305 = icmp slt i32 %5269, %17 + %5306 = icmp slt i32 %5270, %17 + %5307 = icmp slt i32 %5271, %17 + %5308 = icmp slt i32 %5272, %17 + %5309 = icmp slt i32 %5273, %17 + %5310 = sext i32 %5258 to i64 + %5311 = sext i32 %5259 to i64 + %5312 = sext i32 %5260 to i64 + %5313 = sext i32 %5261 to i64 + %5314 = sext i32 %5262 to i64 + %5315 = sext i32 %5263 to i64 + %5316 = sext i32 %5264 to i64 + %5317 = sext i32 %5265 to i64 + %5318 = sext i32 %5266 to i64 + %5319 = sext i32 %5267 to i64 + %5320 = sext i32 %5268 to i64 + %5321 = sext i32 %5269 to i64 + %5322 = sext i32 %5270 to i64 + %5323 = sext i32 %5271 to i64 + %5324 = sext i32 %5272 to i64 + %5325 = sext i32 %5273 to i64 + %5326 = and i1 %5257, %5294 + %5327 = and i1 %5257, %5295 + %5328 = and i1 %5257, %5296 + %5329 = and i1 %5257, %5297 + %5330 = and i1 %5257, %5298 + %5331 = and i1 %5257, %5299 + %5332 = and i1 %5257, %5300 + %5333 = and i1 %5257, %5301 + %5334 = and i1 %5257, %5302 + %5335 = and i1 %5257, %5303 + %5336 = and i1 %5257, %5304 + %5337 = and i1 %5257, %5305 + %5338 = and i1 %5257, %5306 + %5339 = and i1 %5257, %5307 + %5340 = and i1 %5257, %5308 + %5341 = and i1 %5257, %5309 + %5342 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5189 + %5343 = select i1 %5326, i32 4, i32 0 + %5344 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5192 + %5345 = select i1 %5327, i32 4, i32 0 + %5346 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5195 + %5347 = select i1 %5328, i32 4, i32 0 + %5348 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5198 + %5349 = select i1 %5329, i32 4, i32 0 + %5350 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5201 + %5351 = select i1 %5330, i32 4, i32 0 + %5352 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5204 + %5353 = select i1 %5331, i32 4, i32 0 + %5354 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5207 + %5355 = select i1 %5332, i32 4, i32 0 + %5356 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5210 + %5357 = select i1 %5333, i32 4, i32 0 + %5358 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5213 + %5359 = select i1 %5334, i32 4, i32 0 + %5360 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5216 + %5361 = select i1 %5335, i32 4, i32 0 + %5362 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5219 + %5363 = select i1 %5336, i32 4, i32 0 + %5364 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5222 + %5365 = select i1 %5337, i32 4, i32 0 + %5366 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5225 + %5367 = select i1 %5338, i32 4, i32 0 + %5368 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5228 + %5369 = select i1 %5339, i32 4, i32 0 + %5370 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5231 + %5371 = select i1 %5340, i32 4, i32 0 + %5372 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5234 + %5373 = select i1 %5341, i32 4, i32 0 + %5374 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5111 + %5375 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5114 + %5376 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5117 + %5377 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5120 + %5378 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5189 + %5379 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5192 + %5380 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5195 + %5381 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5198 + %5382 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5201 + %5383 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5204 + %5384 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5207 + %5385 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5210 + %5386 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5213 + %5387 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5216 + %5388 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5219 + %5389 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5222 + %5390 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5225 + %5391 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5228 + %5392 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5231 + %5393 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5234 + %5394 = add i32 %5018, -2 + %5395 = add nsw i32 %5018, -1 + %5396 = sext i32 %5083 to i64 + %5397 = sext i32 %5084 to i64 + %5398 = sext i32 %5085 to i64 + %5399 = sext i32 %5086 to i64 + %5400 = sext i32 %5087 to i64 + %5401 = sext i32 %5088 to i64 + %5402 = sext i32 %5089 to i64 + %5403 = sext i32 %5090 to i64 + %5404 = icmp sgt i32 %5091, 0 + %5405 = icmp slt i32 %5079, %17 + %5406 = icmp slt i32 %5080, %17 + %5407 = icmp slt i32 %5081, %17 + %5408 = icmp slt i32 %5082, %17 + %5409 = and i1 %5404, %5405 + %5410 = and i1 %5404, %5406 + %5411 = and i1 %5404, %5407 + %5412 = and i1 %5404, %5408 + %5413 = select i1 %5409, i32 16, i32 0 + %5414 = select i1 %5410, i32 16, i32 0 + %5415 = select i1 %5411, i32 16, i32 0 + %5416 = select i1 %5412, i32 16, i32 0 + %5417 = icmp slt i32 %5051, %17 + %5418 = icmp slt i32 %5052, %17 + %5419 = icmp slt i32 %5053, %17 + %5420 = icmp slt i32 %5054, %17 + %5421 = icmp slt i32 %5056, %17 + %5422 = icmp slt i32 %5058, %17 + %5423 = icmp slt i32 %5060, %17 + %5424 = icmp slt i32 %5062, %17 + %5425 = icmp slt i32 %5064, %17 + %5426 = icmp slt i32 %5066, %17 + %5427 = icmp slt i32 %5068, %17 + %5428 = icmp slt i32 %5070, %17 + %5429 = icmp slt i32 %5072, %17 + %5430 = icmp slt i32 %5074, %17 + %5431 = icmp slt i32 %5076, %17 + %5432 = icmp slt i32 %5078, %17 + %5433 = sext i32 %5051 to i64 + %5434 = sext i32 %5052 to i64 + %5435 = sext i32 %5053 to i64 + %5436 = sext i32 %5054 to i64 + %5437 = sext i32 %5056 to i64 + %5438 = sext i32 %5058 to i64 + %5439 = sext i32 %5060 to i64 + %5440 = sext i32 %5062 to i64 + %5441 = sext i32 %5064 to i64 + %5442 = sext i32 %5066 to i64 + %5443 = sext i32 %5068 to i64 + %5444 = sext i32 %5070 to i64 + %5445 = sext i32 %5072 to i64 + %5446 = sext i32 %5074 to i64 + %5447 = sext i32 %5076 to i64 + %5448 = sext i32 %5078 to i64 + %5449 = and i1 %5404, %5417 + %5450 = and i1 %5404, %5418 + %5451 = and i1 %5404, %5419 + %5452 = and i1 %5404, %5420 + %5453 = and i1 %5404, %5421 + %5454 = and i1 %5404, %5422 + %5455 = and i1 %5404, %5423 + %5456 = and i1 %5404, %5424 + %5457 = and i1 %5404, %5425 + %5458 = and i1 %5404, %5426 + %5459 = and i1 %5404, %5427 + %5460 = and i1 %5404, %5428 + %5461 = and i1 %5404, %5429 + %5462 = and i1 %5404, %5430 + %5463 = and i1 %5404, %5431 + %5464 = and i1 %5404, %5432 + %5465 = select i1 %5449, i32 4, i32 0 + %5466 = select i1 %5450, i32 4, i32 0 + %5467 = select i1 %5451, i32 4, i32 0 + %5468 = select i1 %5452, i32 4, i32 0 + %5469 = select i1 %5453, i32 4, i32 0 + %5470 = select i1 %5454, i32 4, i32 0 + %5471 = select i1 %5455, i32 4, i32 0 + %5472 = select i1 %5456, i32 4, i32 0 + %5473 = select i1 %5457, i32 4, i32 0 + %5474 = select i1 %5458, i32 4, i32 0 + %5475 = select i1 %5459, i32 4, i32 0 + %5476 = select i1 %5460, i32 4, i32 0 + %5477 = select i1 %5461, i32 4, i32 0 + %5478 = select i1 %5462, i32 4, i32 0 + %5479 = select i1 %5463, i32 4, i32 0 + %5480 = select i1 %5464, i32 4, i32 0 + %5481 = icmp sgt i32 %5092, 1 + %5482 = or disjoint i32 %5051, 64 + %5483 = or disjoint i32 %5052, 64 + %5484 = or disjoint i32 %5053, 64 + %5485 = or disjoint i32 %5054, 64 + %5486 = or disjoint i32 %5056, 64 + %5487 = or disjoint i32 %5058, 64 + %5488 = or disjoint i32 %5060, 64 + %5489 = or disjoint i32 %5062, 64 + %5490 = or disjoint i32 %5064, 64 + %5491 = or disjoint i32 %5066, 64 + %5492 = or disjoint i32 %5068, 64 + %5493 = or disjoint i32 %5070, 64 + %5494 = or disjoint i32 %5072, 64 + %5495 = or disjoint i32 %5074, 64 + %5496 = or disjoint i32 %5076, 64 + %5497 = or disjoint i32 %5078, 64 + %5498 = or disjoint i32 %5079, 64 + %5499 = or disjoint i32 %5080, 64 + %5500 = or disjoint i32 %5081, 64 + %5501 = or disjoint i32 %5082, 64 + %5502 = icmp slt i32 %5498, %17 + %5503 = icmp slt i32 %5499, %17 + %5504 = icmp slt i32 %5500, %17 + %5505 = icmp slt i32 %5501, %17 + %5506 = and i1 %5481, %5502 + %5507 = and i1 %5481, %5503 + %5508 = and i1 %5481, %5504 + %5509 = and i1 %5481, %5505 + %5510 = select i1 %5506, i32 16, i32 0 + %5511 = select i1 %5507, i32 16, i32 0 + %5512 = select i1 %5508, i32 16, i32 0 + %5513 = select i1 %5509, i32 16, i32 0 + %5514 = icmp slt i32 %5482, %17 + %5515 = icmp slt i32 %5483, %17 + %5516 = icmp slt i32 %5484, %17 + %5517 = icmp slt i32 %5485, %17 + %5518 = icmp slt i32 %5486, %17 + %5519 = icmp slt i32 %5487, %17 + %5520 = icmp slt i32 %5488, %17 + %5521 = icmp slt i32 %5489, %17 + %5522 = icmp slt i32 %5490, %17 + %5523 = icmp slt i32 %5491, %17 + %5524 = icmp slt i32 %5492, %17 + %5525 = icmp slt i32 %5493, %17 + %5526 = icmp slt i32 %5494, %17 + %5527 = icmp slt i32 %5495, %17 + %5528 = icmp slt i32 %5496, %17 + %5529 = icmp slt i32 %5497, %17 + %5530 = sext i32 %5482 to i64 + %5531 = sext i32 %5483 to i64 + %5532 = sext i32 %5484 to i64 + %5533 = sext i32 %5485 to i64 + %5534 = sext i32 %5486 to i64 + %5535 = sext i32 %5487 to i64 + %5536 = sext i32 %5488 to i64 + %5537 = sext i32 %5489 to i64 + %5538 = sext i32 %5490 to i64 + %5539 = sext i32 %5491 to i64 + %5540 = sext i32 %5492 to i64 + %5541 = sext i32 %5493 to i64 + %5542 = sext i32 %5494 to i64 + %5543 = sext i32 %5495 to i64 + %5544 = sext i32 %5496 to i64 + %5545 = sext i32 %5497 to i64 + %5546 = and i1 %5481, %5514 + %5547 = and i1 %5481, %5515 + %5548 = and i1 %5481, %5516 + %5549 = and i1 %5481, %5517 + %5550 = and i1 %5481, %5518 + %5551 = and i1 %5481, %5519 + %5552 = and i1 %5481, %5520 + %5553 = and i1 %5481, %5521 + %5554 = and i1 %5481, %5522 + %5555 = and i1 %5481, %5523 + %5556 = and i1 %5481, %5524 + %5557 = and i1 %5481, %5525 + %5558 = and i1 %5481, %5526 + %5559 = and i1 %5481, %5527 + %5560 = and i1 %5481, %5528 + %5561 = and i1 %5481, %5529 + %5562 = select i1 %5546, i32 4, i32 0 + %5563 = select i1 %5547, i32 4, i32 0 + %5564 = select i1 %5548, i32 4, i32 0 + %5565 = select i1 %5549, i32 4, i32 0 + %5566 = select i1 %5550, i32 4, i32 0 + %5567 = select i1 %5551, i32 4, i32 0 + %5568 = select i1 %5552, i32 4, i32 0 + %5569 = select i1 %5553, i32 4, i32 0 + %5570 = select i1 %5554, i32 4, i32 0 + %5571 = select i1 %5555, i32 4, i32 0 + %5572 = select i1 %5556, i32 4, i32 0 + %5573 = select i1 %5557, i32 4, i32 0 + %5574 = select i1 %5558, i32 4, i32 0 + %5575 = select i1 %5559, i32 4, i32 0 + %5576 = select i1 %5560, i32 4, i32 0 + %5577 = select i1 %5561, i32 4, i32 0 + %5578 = add i32 %5092, -2 + %5579 = add nsw i32 %5092, -1 + %smax2265 = tail call i32 @llvm.smax.i32(i32 %5018, i32 1), !dbg !251 + %smax2267 = tail call i32 @llvm.smax.i32(i32 %5092, i32 1), !dbg !251 + %5580 = zext nneg i32 %4985 to i64, !dbg !251 + %5581 = insertelement <16 x i32> poison, i32 %17, i64 0 + %5582 = shufflevector <16 x i32> %5581, <16 x i32> poison, <16 x i32> zeroinitializer + %5583 = extractelement <2 x i32> %5043, i64 1 + %5584 = extractelement <2 x i32> %5043, i64 0 + br label %5585, !dbg !251 + +5585: ; preds = %4740, %._crit_edge1874 + %indvars.iv = phi i64 [ 0, %4740 ], [ %indvars.iv.next, %._crit_edge1874 ] + %5586 = phi float [ 0.000000e+00, %4740 ], [ %11045, %._crit_edge1874 ] + %5587 = phi float [ 0.000000e+00, %4740 ], [ %11046, %._crit_edge1874 ] + %5588 = phi float [ 0.000000e+00, %4740 ], [ %11047, %._crit_edge1874 ] + %5589 = phi float [ 0.000000e+00, %4740 ], [ %11048, %._crit_edge1874 ] + %5590 = phi float [ 0.000000e+00, %4740 ], [ %11049, %._crit_edge1874 ] + %5591 = phi float [ 0.000000e+00, %4740 ], [ %11050, %._crit_edge1874 ] + %5592 = phi float [ 0.000000e+00, %4740 ], [ %11051, %._crit_edge1874 ] + %5593 = phi float [ 0.000000e+00, %4740 ], [ %11052, %._crit_edge1874 ] + %5594 = phi float [ 0.000000e+00, %4740 ], [ %11053, %._crit_edge1874 ] + %5595 = phi float [ 0.000000e+00, %4740 ], [ %11054, %._crit_edge1874 ] + %5596 = phi float [ 0.000000e+00, %4740 ], [ %11055, %._crit_edge1874 ] + %5597 = phi float [ 0.000000e+00, %4740 ], [ %11056, %._crit_edge1874 ] + %5598 = phi float [ 0.000000e+00, %4740 ], [ %11057, %._crit_edge1874 ] + %5599 = phi float [ 0.000000e+00, %4740 ], [ %11058, %._crit_edge1874 ] + %5600 = phi float [ 0.000000e+00, %4740 ], [ %11059, %._crit_edge1874 ] + %5601 = phi float [ 0.000000e+00, %4740 ], [ %11060, %._crit_edge1874 ] + %5602 = phi float [ 0.000000e+00, %4740 ], [ %11061, %._crit_edge1874 ] + %5603 = phi float [ 0.000000e+00, %4740 ], [ %11062, %._crit_edge1874 ] + %5604 = phi float [ 0.000000e+00, %4740 ], [ %11063, %._crit_edge1874 ] + %5605 = phi float [ 0.000000e+00, %4740 ], [ %11064, %._crit_edge1874 ] + %5606 = phi float [ 0.000000e+00, %4740 ], [ %11065, %._crit_edge1874 ] + %5607 = phi float [ 0.000000e+00, %4740 ], [ %11066, %._crit_edge1874 ] + %5608 = phi float [ 0.000000e+00, %4740 ], [ %11067, %._crit_edge1874 ] + %5609 = phi float [ 0.000000e+00, %4740 ], [ %11068, %._crit_edge1874 ] + %5610 = phi float [ 0.000000e+00, %4740 ], [ %11069, %._crit_edge1874 ] + %5611 = phi float [ 0.000000e+00, %4740 ], [ %11070, %._crit_edge1874 ] + %5612 = phi float [ 0.000000e+00, %4740 ], [ %11071, %._crit_edge1874 ] + %5613 = phi float [ 0.000000e+00, %4740 ], [ %11072, %._crit_edge1874 ] + %5614 = phi float [ 0.000000e+00, %4740 ], [ %11073, %._crit_edge1874 ] + %5615 = phi float [ 0.000000e+00, %4740 ], [ %11074, %._crit_edge1874 ] + %5616 = phi float [ 0.000000e+00, %4740 ], [ %11075, %._crit_edge1874 ] + %5617 = phi float [ 0.000000e+00, %4740 ], [ %11076, %._crit_edge1874 ] + %5618 = phi float [ 0.000000e+00, %4740 ], [ %11077, %._crit_edge1874 ] + %5619 = phi float [ 0.000000e+00, %4740 ], [ %11078, %._crit_edge1874 ] + %5620 = phi float [ 0.000000e+00, %4740 ], [ %11079, %._crit_edge1874 ] + %5621 = phi float [ 0.000000e+00, %4740 ], [ %11080, %._crit_edge1874 ] + %5622 = phi float [ 0.000000e+00, %4740 ], [ %11081, %._crit_edge1874 ] + %5623 = phi float [ 0.000000e+00, %4740 ], [ %11082, %._crit_edge1874 ] + %5624 = phi float [ 0.000000e+00, %4740 ], [ %11083, %._crit_edge1874 ] + %5625 = phi float [ 0.000000e+00, %4740 ], [ %11084, %._crit_edge1874 ] + %5626 = phi float [ 0.000000e+00, %4740 ], [ %11085, %._crit_edge1874 ] + %5627 = phi float [ 0.000000e+00, %4740 ], [ %11086, %._crit_edge1874 ] + %5628 = phi float [ 0.000000e+00, %4740 ], [ %11087, %._crit_edge1874 ] + %5629 = phi float [ 0.000000e+00, %4740 ], [ %11088, %._crit_edge1874 ] + %5630 = phi float [ 0.000000e+00, %4740 ], [ %11089, %._crit_edge1874 ] + %5631 = phi float [ 0.000000e+00, %4740 ], [ %11090, %._crit_edge1874 ] + %5632 = phi float [ 0.000000e+00, %4740 ], [ %11091, %._crit_edge1874 ] + %5633 = phi float [ 0.000000e+00, %4740 ], [ %11092, %._crit_edge1874 ] + %5634 = phi float [ 0.000000e+00, %4740 ], [ %11093, %._crit_edge1874 ] + %5635 = phi float [ 0.000000e+00, %4740 ], [ %11094, %._crit_edge1874 ] + %5636 = phi float [ 0.000000e+00, %4740 ], [ %11095, %._crit_edge1874 ] + %5637 = phi float [ 0.000000e+00, %4740 ], [ %11096, %._crit_edge1874 ] + %5638 = phi float [ 0.000000e+00, %4740 ], [ %11097, %._crit_edge1874 ] + %5639 = phi float [ 0.000000e+00, %4740 ], [ %11098, %._crit_edge1874 ] + %5640 = phi float [ 0.000000e+00, %4740 ], [ %11099, %._crit_edge1874 ] + %5641 = phi float [ 0.000000e+00, %4740 ], [ %11100, %._crit_edge1874 ] + %5642 = phi float [ 0.000000e+00, %4740 ], [ %11101, %._crit_edge1874 ] + %5643 = phi float [ 0.000000e+00, %4740 ], [ %11102, %._crit_edge1874 ] + %5644 = phi float [ 0.000000e+00, %4740 ], [ %11103, %._crit_edge1874 ] + %5645 = phi float [ 0.000000e+00, %4740 ], [ %11104, %._crit_edge1874 ] + %5646 = phi float [ 0.000000e+00, %4740 ], [ %11105, %._crit_edge1874 ] + %5647 = phi float [ 0.000000e+00, %4740 ], [ %11106, %._crit_edge1874 ] + %5648 = phi float [ 0.000000e+00, %4740 ], [ %11107, %._crit_edge1874 ] + %5649 = phi float [ 0.000000e+00, %4740 ], [ %11108, %._crit_edge1874 ] + %5650 = phi float [ 0.000000e+00, %4740 ], [ %10981, %._crit_edge1874 ] + %5651 = phi float [ 0.000000e+00, %4740 ], [ %10982, %._crit_edge1874 ] + %5652 = phi float [ 0.000000e+00, %4740 ], [ %10983, %._crit_edge1874 ] + %5653 = phi float [ 0.000000e+00, %4740 ], [ %10984, %._crit_edge1874 ] + %5654 = phi float [ 0.000000e+00, %4740 ], [ %10985, %._crit_edge1874 ] + %5655 = phi float [ 0.000000e+00, %4740 ], [ %10986, %._crit_edge1874 ] + %5656 = phi float [ 0.000000e+00, %4740 ], [ %10987, %._crit_edge1874 ] + %5657 = phi float [ 0.000000e+00, %4740 ], [ %10988, %._crit_edge1874 ] + %5658 = phi float [ 0.000000e+00, %4740 ], [ %10989, %._crit_edge1874 ] + %5659 = phi float [ 0.000000e+00, %4740 ], [ %10990, %._crit_edge1874 ] + %5660 = phi float [ 0.000000e+00, %4740 ], [ %10991, %._crit_edge1874 ] + %5661 = phi float [ 0.000000e+00, %4740 ], [ %10992, %._crit_edge1874 ] + %5662 = phi float [ 0.000000e+00, %4740 ], [ %10993, %._crit_edge1874 ] + %5663 = phi float [ 0.000000e+00, %4740 ], [ %10994, %._crit_edge1874 ] + %5664 = phi float [ 0.000000e+00, %4740 ], [ %10995, %._crit_edge1874 ] + %5665 = phi float [ 0.000000e+00, %4740 ], [ %10996, %._crit_edge1874 ] + %5666 = phi float [ 0.000000e+00, %4740 ], [ %10997, %._crit_edge1874 ] + %5667 = phi float [ 0.000000e+00, %4740 ], [ %10998, %._crit_edge1874 ] + %5668 = phi float [ 0.000000e+00, %4740 ], [ %10999, %._crit_edge1874 ] + %5669 = phi float [ 0.000000e+00, %4740 ], [ %11000, %._crit_edge1874 ] + %5670 = phi float [ 0.000000e+00, %4740 ], [ %11001, %._crit_edge1874 ] + %5671 = phi float [ 0.000000e+00, %4740 ], [ %11002, %._crit_edge1874 ] + %5672 = phi float [ 0.000000e+00, %4740 ], [ %11003, %._crit_edge1874 ] + %5673 = phi float [ 0.000000e+00, %4740 ], [ %11004, %._crit_edge1874 ] + %5674 = phi float [ 0.000000e+00, %4740 ], [ %11005, %._crit_edge1874 ] + %5675 = phi float [ 0.000000e+00, %4740 ], [ %11006, %._crit_edge1874 ] + %5676 = phi float [ 0.000000e+00, %4740 ], [ %11007, %._crit_edge1874 ] + %5677 = phi float [ 0.000000e+00, %4740 ], [ %11008, %._crit_edge1874 ] + %5678 = phi float [ 0.000000e+00, %4740 ], [ %11009, %._crit_edge1874 ] + %5679 = phi float [ 0.000000e+00, %4740 ], [ %11010, %._crit_edge1874 ] + %5680 = phi float [ 0.000000e+00, %4740 ], [ %11011, %._crit_edge1874 ] + %5681 = phi float [ 0.000000e+00, %4740 ], [ %11012, %._crit_edge1874 ] + %5682 = phi float [ 0.000000e+00, %4740 ], [ %11013, %._crit_edge1874 ] + %5683 = phi float [ 0.000000e+00, %4740 ], [ %11014, %._crit_edge1874 ] + %5684 = phi float [ 0.000000e+00, %4740 ], [ %11015, %._crit_edge1874 ] + %5685 = phi float [ 0.000000e+00, %4740 ], [ %11016, %._crit_edge1874 ] + %5686 = phi float [ 0.000000e+00, %4740 ], [ %11017, %._crit_edge1874 ] + %5687 = phi float [ 0.000000e+00, %4740 ], [ %11018, %._crit_edge1874 ] + %5688 = phi float [ 0.000000e+00, %4740 ], [ %11019, %._crit_edge1874 ] + %5689 = phi float [ 0.000000e+00, %4740 ], [ %11020, %._crit_edge1874 ] + %5690 = phi float [ 0.000000e+00, %4740 ], [ %11021, %._crit_edge1874 ] + %5691 = phi float [ 0.000000e+00, %4740 ], [ %11022, %._crit_edge1874 ] + %5692 = phi float [ 0.000000e+00, %4740 ], [ %11023, %._crit_edge1874 ] + %5693 = phi float [ 0.000000e+00, %4740 ], [ %11024, %._crit_edge1874 ] + %5694 = phi float [ 0.000000e+00, %4740 ], [ %11025, %._crit_edge1874 ] + %5695 = phi float [ 0.000000e+00, %4740 ], [ %11026, %._crit_edge1874 ] + %5696 = phi float [ 0.000000e+00, %4740 ], [ %11027, %._crit_edge1874 ] + %5697 = phi float [ 0.000000e+00, %4740 ], [ %11028, %._crit_edge1874 ] + %5698 = phi float [ 0.000000e+00, %4740 ], [ %11029, %._crit_edge1874 ] + %5699 = phi float [ 0.000000e+00, %4740 ], [ %11030, %._crit_edge1874 ] + %5700 = phi float [ 0.000000e+00, %4740 ], [ %11031, %._crit_edge1874 ] + %5701 = phi float [ 0.000000e+00, %4740 ], [ %11032, %._crit_edge1874 ] + %5702 = phi float [ 0.000000e+00, %4740 ], [ %11033, %._crit_edge1874 ] + %5703 = phi float [ 0.000000e+00, %4740 ], [ %11034, %._crit_edge1874 ] + %5704 = phi float [ 0.000000e+00, %4740 ], [ %11035, %._crit_edge1874 ] + %5705 = phi float [ 0.000000e+00, %4740 ], [ %11036, %._crit_edge1874 ] + %5706 = phi float [ 0.000000e+00, %4740 ], [ %11037, %._crit_edge1874 ] + %5707 = phi float [ 0.000000e+00, %4740 ], [ %11038, %._crit_edge1874 ] + %5708 = phi float [ 0.000000e+00, %4740 ], [ %11039, %._crit_edge1874 ] + %5709 = phi float [ 0.000000e+00, %4740 ], [ %11040, %._crit_edge1874 ] + %5710 = phi float [ 0.000000e+00, %4740 ], [ %11041, %._crit_edge1874 ] + %5711 = phi float [ 0.000000e+00, %4740 ], [ %11042, %._crit_edge1874 ] + %5712 = phi float [ 0.000000e+00, %4740 ], [ %11043, %._crit_edge1874 ] + %5713 = phi float [ 0.000000e+00, %4740 ], [ %11044, %._crit_edge1874 ] + %5714 = add nuw nsw i64 %indvars.iv, %5580, !dbg !252 + %.tr = trunc i64 %5714 to i32, !dbg !253 + %5715 = shl i32 %.tr, 7, !dbg !253 + %5716 = add i32 %5715, %4986, !dbg !253 + %5717 = sext i32 %5716 to i64, !dbg !254 + %5718 = trunc nuw nsw i64 %5714 to i32, !dbg !255 + %5719 = mul i32 %35, %5718, !dbg !255 + %5720 = add i32 %5719, %4987, !dbg !256 + %5721 = sext i32 %5720 to i64, !dbg !257 + %5722 = trunc i64 %5714 to i32, !dbg !258 + %5723 = add i32 %4988, %5722, !dbg !258 + %5724 = mul i32 %5723, %17, !dbg !258 + %5725 = sext i32 %5724 to i64, !dbg !259 + %5726 = getelementptr bfloat, ptr addrspace(1) %0, i64 %5717, !dbg !260 + %5727 = getelementptr bfloat, ptr addrspace(1) %5, i64 %5721, !dbg !261 + %5728 = getelementptr float, ptr addrspace(1) %3, i64 %5725, !dbg !262 + %5729 = getelementptr float, ptr addrspace(1) %4, i64 %5725, !dbg !263 + %5730 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5093, !dbg !264 + %5731 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5094, !dbg !264 + %5732 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5095, !dbg !264 + %5733 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5096, !dbg !264 + %5734 = getelementptr bfloat, ptr addrspace(1) %5730, i64 %4776, !dbg !265 + %5735 = getelementptr bfloat, ptr addrspace(1) %5731, i64 %4776, !dbg !265 + %5736 = getelementptr bfloat, ptr addrspace(1) %5732, i64 %4776, !dbg !265 + %5737 = getelementptr bfloat, ptr addrspace(1) %5733, i64 %4776, !dbg !265 + %5738 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5097, !dbg !266 + %5739 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5098, !dbg !266 + %5740 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5099, !dbg !266 + %5741 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5100, !dbg !266 + %5742 = getelementptr bfloat, ptr addrspace(1) %5738, i64 %4776, !dbg !267 + %5743 = getelementptr bfloat, ptr addrspace(1) %5739, i64 %4776, !dbg !267 + %5744 = getelementptr bfloat, ptr addrspace(1) %5740, i64 %4776, !dbg !267 + %5745 = getelementptr bfloat, ptr addrspace(1) %5741, i64 %4776, !dbg !267 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5112, ptr addrspace(1) %5734, i32 %5113) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5115, ptr addrspace(1) %5735, i32 %5116) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5118, ptr addrspace(1) %5736, i32 %5119) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5121, ptr addrspace(1) %5737, i32 %5122) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + %5746 = getelementptr float, ptr addrspace(1) %5728, i64 %5155, !dbg !269 + %5747 = getelementptr float, ptr addrspace(1) %5728, i64 %5156, !dbg !269 + %5748 = getelementptr float, ptr addrspace(1) %5728, i64 %5157, !dbg !269 + %5749 = getelementptr float, ptr addrspace(1) %5728, i64 %5158, !dbg !269 + %5750 = getelementptr float, ptr addrspace(1) %5728, i64 %5159, !dbg !269 + %5751 = getelementptr float, ptr addrspace(1) %5728, i64 %5160, !dbg !269 + %5752 = getelementptr float, ptr addrspace(1) %5728, i64 %5161, !dbg !269 + %5753 = getelementptr float, ptr addrspace(1) %5728, i64 %5162, !dbg !269 + %5754 = getelementptr float, ptr addrspace(1) %5728, i64 %5163, !dbg !269 + %5755 = getelementptr float, ptr addrspace(1) %5728, i64 %5164, !dbg !269 + %5756 = getelementptr float, ptr addrspace(1) %5728, i64 %5165, !dbg !269 + %5757 = getelementptr float, ptr addrspace(1) %5728, i64 %5166, !dbg !269 + %5758 = getelementptr float, ptr addrspace(1) %5728, i64 %5167, !dbg !269 + %5759 = getelementptr float, ptr addrspace(1) %5728, i64 %5168, !dbg !269 + %5760 = getelementptr float, ptr addrspace(1) %5728, i64 %5169, !dbg !269 + %5761 = getelementptr float, ptr addrspace(1) %5728, i64 %5170, !dbg !269 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5190, ptr addrspace(1) %5746, i32 %5191, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5193, ptr addrspace(1) %5747, i32 %5194, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5196, ptr addrspace(1) %5748, i32 %5197, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5199, ptr addrspace(1) %5749, i32 %5200, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5202, ptr addrspace(1) %5750, i32 %5203, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5205, ptr addrspace(1) %5751, i32 %5206, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5208, ptr addrspace(1) %5752, i32 %5209, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5211, ptr addrspace(1) %5753, i32 %5212, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5214, ptr addrspace(1) %5754, i32 %5215, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5217, ptr addrspace(1) %5755, i32 %5218, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5220, ptr addrspace(1) %5756, i32 %5221, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5223, ptr addrspace(1) %5757, i32 %5224, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5226, ptr addrspace(1) %5758, i32 %5227, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5229, ptr addrspace(1) %5759, i32 %5230, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5232, ptr addrspace(1) %5760, i32 %5233, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5235, ptr addrspace(1) %5761, i32 %5236, i1 %5188) #3, !dbg !270 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !270 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5237, ptr addrspace(1) %5742, i32 %5113) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5238, ptr addrspace(1) %5743, i32 %5116) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5239, ptr addrspace(1) %5744, i32 %5119) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5240, ptr addrspace(1) %5745, i32 %5122) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + %5762 = getelementptr float, ptr addrspace(1) %5729, i64 %5155, !dbg !272 + %5763 = getelementptr float, ptr addrspace(1) %5729, i64 %5156, !dbg !272 + %5764 = getelementptr float, ptr addrspace(1) %5729, i64 %5157, !dbg !272 + %5765 = getelementptr float, ptr addrspace(1) %5729, i64 %5158, !dbg !272 + %5766 = getelementptr float, ptr addrspace(1) %5729, i64 %5159, !dbg !272 + %5767 = getelementptr float, ptr addrspace(1) %5729, i64 %5160, !dbg !272 + %5768 = getelementptr float, ptr addrspace(1) %5729, i64 %5161, !dbg !272 + %5769 = getelementptr float, ptr addrspace(1) %5729, i64 %5162, !dbg !272 + %5770 = getelementptr float, ptr addrspace(1) %5729, i64 %5163, !dbg !272 + %5771 = getelementptr float, ptr addrspace(1) %5729, i64 %5164, !dbg !272 + %5772 = getelementptr float, ptr addrspace(1) %5729, i64 %5165, !dbg !272 + %5773 = getelementptr float, ptr addrspace(1) %5729, i64 %5166, !dbg !272 + %5774 = getelementptr float, ptr addrspace(1) %5729, i64 %5167, !dbg !272 + %5775 = getelementptr float, ptr addrspace(1) %5729, i64 %5168, !dbg !272 + %5776 = getelementptr float, ptr addrspace(1) %5729, i64 %5169, !dbg !272 + %5777 = getelementptr float, ptr addrspace(1) %5729, i64 %5170, !dbg !272 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5241, ptr addrspace(1) %5762, i32 %5191, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5242, ptr addrspace(1) %5763, i32 %5194, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5243, ptr addrspace(1) %5764, i32 %5197, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5244, ptr addrspace(1) %5765, i32 %5200, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5245, ptr addrspace(1) %5766, i32 %5203, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5246, ptr addrspace(1) %5767, i32 %5206, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5247, ptr addrspace(1) %5768, i32 %5209, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5248, ptr addrspace(1) %5769, i32 %5212, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5249, ptr addrspace(1) %5770, i32 %5215, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5250, ptr addrspace(1) %5771, i32 %5218, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5251, ptr addrspace(1) %5772, i32 %5221, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5252, ptr addrspace(1) %5773, i32 %5224, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5253, ptr addrspace(1) %5774, i32 %5227, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5254, ptr addrspace(1) %5775, i32 %5230, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5255, ptr addrspace(1) %5776, i32 %5233, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5256, ptr addrspace(1) %5777, i32 %5236, i1 %5188) #3, !dbg !273 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !273 + %5778 = getelementptr i8, ptr addrspace(1) %5734, i64 524288, !dbg !274 + %5779 = getelementptr i8, ptr addrspace(1) %5735, i64 524288, !dbg !274 + %5780 = getelementptr i8, ptr addrspace(1) %5736, i64 524288, !dbg !274 + %5781 = getelementptr i8, ptr addrspace(1) %5737, i64 524288, !dbg !274 + %5782 = getelementptr i8, ptr addrspace(1) %5742, i64 16384, !dbg !275 + %5783 = getelementptr i8, ptr addrspace(1) %5743, i64 16384, !dbg !275 + %5784 = getelementptr i8, ptr addrspace(1) %5744, i64 16384, !dbg !275 + %5785 = getelementptr i8, ptr addrspace(1) %5745, i64 16384, !dbg !275 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5286, ptr addrspace(1) %5778, i32 %5287) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5288, ptr addrspace(1) %5779, i32 %5289) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5290, ptr addrspace(1) %5780, i32 %5291) #3, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5292, ptr addrspace(1) %5781, i32 %5293) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + %5786 = getelementptr float, ptr addrspace(1) %5728, i64 %5310, !dbg !269 + %5787 = getelementptr float, ptr addrspace(1) %5728, i64 %5311, !dbg !269 + %5788 = getelementptr float, ptr addrspace(1) %5728, i64 %5312, !dbg !269 + %5789 = getelementptr float, ptr addrspace(1) %5728, i64 %5313, !dbg !269 + %5790 = getelementptr float, ptr addrspace(1) %5728, i64 %5314, !dbg !269 + %5791 = getelementptr float, ptr addrspace(1) %5728, i64 %5315, !dbg !269 + %5792 = getelementptr float, ptr addrspace(1) %5728, i64 %5316, !dbg !269 + %5793 = getelementptr float, ptr addrspace(1) %5728, i64 %5317, !dbg !269 + %5794 = getelementptr float, ptr addrspace(1) %5728, i64 %5318, !dbg !269 + %5795 = getelementptr float, ptr addrspace(1) %5728, i64 %5319, !dbg !269 + %5796 = getelementptr float, ptr addrspace(1) %5728, i64 %5320, !dbg !269 + %5797 = getelementptr float, ptr addrspace(1) %5728, i64 %5321, !dbg !269 + %5798 = getelementptr float, ptr addrspace(1) %5728, i64 %5322, !dbg !269 + %5799 = getelementptr float, ptr addrspace(1) %5728, i64 %5323, !dbg !269 + %5800 = getelementptr float, ptr addrspace(1) %5728, i64 %5324, !dbg !269 + %5801 = getelementptr float, ptr addrspace(1) %5728, i64 %5325, !dbg !269 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5342, ptr addrspace(1) %5786, i32 %5343, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5344, ptr addrspace(1) %5787, i32 %5345, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5346, ptr addrspace(1) %5788, i32 %5347, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5348, ptr addrspace(1) %5789, i32 %5349, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5350, ptr addrspace(1) %5790, i32 %5351, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5352, ptr addrspace(1) %5791, i32 %5353, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5354, ptr addrspace(1) %5792, i32 %5355, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5356, ptr addrspace(1) %5793, i32 %5357, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5358, ptr addrspace(1) %5794, i32 %5359, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5360, ptr addrspace(1) %5795, i32 %5361, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5362, ptr addrspace(1) %5796, i32 %5363, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5364, ptr addrspace(1) %5797, i32 %5365, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5366, ptr addrspace(1) %5798, i32 %5367, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5368, ptr addrspace(1) %5799, i32 %5369, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5370, ptr addrspace(1) %5800, i32 %5371, i1 %5188) #3, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5372, ptr addrspace(1) %5801, i32 %5373, i1 %5188) #3, !dbg !270 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !270 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5374, ptr addrspace(1) %5782, i32 %5287) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5375, ptr addrspace(1) %5783, i32 %5289) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5376, ptr addrspace(1) %5784, i32 %5291) #3, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5377, ptr addrspace(1) %5785, i32 %5293) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + %5802 = getelementptr float, ptr addrspace(1) %5729, i64 %5310, !dbg !272 + %5803 = getelementptr float, ptr addrspace(1) %5729, i64 %5311, !dbg !272 + %5804 = getelementptr float, ptr addrspace(1) %5729, i64 %5312, !dbg !272 + %5805 = getelementptr float, ptr addrspace(1) %5729, i64 %5313, !dbg !272 + %5806 = getelementptr float, ptr addrspace(1) %5729, i64 %5314, !dbg !272 + %5807 = getelementptr float, ptr addrspace(1) %5729, i64 %5315, !dbg !272 + %5808 = getelementptr float, ptr addrspace(1) %5729, i64 %5316, !dbg !272 + %5809 = getelementptr float, ptr addrspace(1) %5729, i64 %5317, !dbg !272 + %5810 = getelementptr float, ptr addrspace(1) %5729, i64 %5318, !dbg !272 + %5811 = getelementptr float, ptr addrspace(1) %5729, i64 %5319, !dbg !272 + %5812 = getelementptr float, ptr addrspace(1) %5729, i64 %5320, !dbg !272 + %5813 = getelementptr float, ptr addrspace(1) %5729, i64 %5321, !dbg !272 + %5814 = getelementptr float, ptr addrspace(1) %5729, i64 %5322, !dbg !272 + %5815 = getelementptr float, ptr addrspace(1) %5729, i64 %5323, !dbg !272 + %5816 = getelementptr float, ptr addrspace(1) %5729, i64 %5324, !dbg !272 + %5817 = getelementptr float, ptr addrspace(1) %5729, i64 %5325, !dbg !272 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5378, ptr addrspace(1) %5802, i32 %5343, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5379, ptr addrspace(1) %5803, i32 %5345, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5380, ptr addrspace(1) %5804, i32 %5347, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5381, ptr addrspace(1) %5805, i32 %5349, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5382, ptr addrspace(1) %5806, i32 %5351, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5383, ptr addrspace(1) %5807, i32 %5353, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5384, ptr addrspace(1) %5808, i32 %5355, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5385, ptr addrspace(1) %5809, i32 %5357, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5386, ptr addrspace(1) %5810, i32 %5359, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5387, ptr addrspace(1) %5811, i32 %5361, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5388, ptr addrspace(1) %5812, i32 %5363, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5389, ptr addrspace(1) %5813, i32 %5365, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5390, ptr addrspace(1) %5814, i32 %5367, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5391, ptr addrspace(1) %5815, i32 %5369, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5392, ptr addrspace(1) %5816, i32 %5371, i1 %5188) #3, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5393, ptr addrspace(1) %5817, i32 %5373, i1 %5188) #3, !dbg !273 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !273 + br i1 %5101, label %.lr.ph1700, label %._crit_edge1701, !dbg !276 + +.lr.ph1700: ; preds = %5585, %__nv_exp2f.exit1417 + %5818 = phi i32 [ %8217, %__nv_exp2f.exit1417 ], [ 64, %5585 ] + %5819 = phi i32 [ %.pn2191683, %__nv_exp2f.exit1417 ], [ %5123, %5585 ] + %5820 = phi i32 [ %.pn2171684, %__nv_exp2f.exit1417 ], [ %5125, %5585 ] + %5821 = phi i32 [ %.pn2151685, %__nv_exp2f.exit1417 ], [ %5127, %5585 ] + %5822 = phi i32 [ %.pn2131686, %__nv_exp2f.exit1417 ], [ %5129, %5585 ] + %5823 = phi i32 [ %.pn2111687, %__nv_exp2f.exit1417 ], [ %5131, %5585 ] + %5824 = phi i32 [ %.pn2091688, %__nv_exp2f.exit1417 ], [ %5133, %5585 ] + %5825 = phi i32 [ %.pn2071689, %__nv_exp2f.exit1417 ], [ %5135, %5585 ] + %5826 = phi i32 [ %.pn2051690, %__nv_exp2f.exit1417 ], [ %5137, %5585 ] + %5827 = phi i32 [ %.pn2031691, %__nv_exp2f.exit1417 ], [ %5139, %5585 ] + %5828 = phi i32 [ %.pn2011692, %__nv_exp2f.exit1417 ], [ %5141, %5585 ] + %5829 = phi i32 [ %.pn1991693, %__nv_exp2f.exit1417 ], [ %5143, %5585 ] + %5830 = phi i32 [ %.pn1971694, %__nv_exp2f.exit1417 ], [ %5145, %5585 ] + %5831 = phi i32 [ %.pn1951695, %__nv_exp2f.exit1417 ], [ %5147, %5585 ] + %5832 = phi i32 [ %.pn1931696, %__nv_exp2f.exit1417 ], [ %5149, %5585 ] + %5833 = phi i32 [ %.pn1911697, %__nv_exp2f.exit1417 ], [ %5151, %5585 ] + %5834 = phi i32 [ %.pn1891698, %__nv_exp2f.exit1417 ], [ %5153, %5585 ] + %5835 = phi i32 [ %5981, %__nv_exp2f.exit1417 ], [ -1, %5585 ] + %5836 = phi i32 [ %8256, %__nv_exp2f.exit1417 ], [ 1, %5585 ] + %5837 = phi i32 [ %5984, %__nv_exp2f.exit1417 ], [ -1, %5585 ] + %5838 = phi i32 [ %8259, %__nv_exp2f.exit1417 ], [ 1, %5585 ] + %.pn1891698 = phi i32 [ %8245, %__nv_exp2f.exit1417 ], [ %5273, %5585 ] + %.pn1911697 = phi i32 [ %8244, %__nv_exp2f.exit1417 ], [ %5272, %5585 ] + %.pn1931696 = phi i32 [ %8243, %__nv_exp2f.exit1417 ], [ %5271, %5585 ] + %.pn1951695 = phi i32 [ %8242, %__nv_exp2f.exit1417 ], [ %5270, %5585 ] + %.pn1971694 = phi i32 [ %8241, %__nv_exp2f.exit1417 ], [ %5269, %5585 ] + %.pn1991693 = phi i32 [ %8240, %__nv_exp2f.exit1417 ], [ %5268, %5585 ] + %.pn2011692 = phi i32 [ %8239, %__nv_exp2f.exit1417 ], [ %5267, %5585 ] + %.pn2031691 = phi i32 [ %8238, %__nv_exp2f.exit1417 ], [ %5266, %5585 ] + %.pn2051690 = phi i32 [ %8237, %__nv_exp2f.exit1417 ], [ %5265, %5585 ] + %.pn2071689 = phi i32 [ %8236, %__nv_exp2f.exit1417 ], [ %5264, %5585 ] + %.pn2091688 = phi i32 [ %8235, %__nv_exp2f.exit1417 ], [ %5263, %5585 ] + %.pn2111687 = phi i32 [ %8234, %__nv_exp2f.exit1417 ], [ %5262, %5585 ] + %.pn2131686 = phi i32 [ %8233, %__nv_exp2f.exit1417 ], [ %5261, %5585 ] + %.pn2151685 = phi i32 [ %8232, %__nv_exp2f.exit1417 ], [ %5260, %5585 ] + %.pn2171684 = phi i32 [ %8231, %__nv_exp2f.exit1417 ], [ %5259, %5585 ] + %.pn2191683 = phi i32 [ %8230, %__nv_exp2f.exit1417 ], [ %5258, %5585 ] + %5839 = phi i32 [ %8250, %__nv_exp2f.exit1417 ], [ %5274, %5585 ] + %5840 = phi i32 [ %8251, %__nv_exp2f.exit1417 ], [ %5275, %5585 ] + %5841 = phi i32 [ %8252, %__nv_exp2f.exit1417 ], [ %5276, %5585 ] + %5842 = phi i32 [ %8253, %__nv_exp2f.exit1417 ], [ %5277, %5585 ] + %.pn1391682 = phi ptr addrspace(1) [ %8229, %__nv_exp2f.exit1417 ], [ %5785, %5585 ] + %.pn1551681 = phi ptr addrspace(1) [ %8228, %__nv_exp2f.exit1417 ], [ %5784, %5585 ] + %.pn1711680 = phi ptr addrspace(1) [ %8227, %__nv_exp2f.exit1417 ], [ %5783, %5585 ] + %.pn1871679 = phi ptr addrspace(1) [ %8226, %__nv_exp2f.exit1417 ], [ %5782, %5585 ] + %5843 = phi i32 [ %8246, %__nv_exp2f.exit1417 ], [ %5274, %5585 ] + %5844 = phi i32 [ %8247, %__nv_exp2f.exit1417 ], [ %5275, %5585 ] + %5845 = phi i32 [ %8248, %__nv_exp2f.exit1417 ], [ %5276, %5585 ] + %5846 = phi i32 [ %8249, %__nv_exp2f.exit1417 ], [ %5277, %5585 ] + %.pn751678 = phi ptr addrspace(1) [ %8223, %__nv_exp2f.exit1417 ], [ %5781, %5585 ] + %.pn911677 = phi ptr addrspace(1) [ %8222, %__nv_exp2f.exit1417 ], [ %5780, %5585 ] + %.pn1071676 = phi ptr addrspace(1) [ %8221, %__nv_exp2f.exit1417 ], [ %5779, %5585 ] + %.pn1231675 = phi ptr addrspace(1) [ %8220, %__nv_exp2f.exit1417 ], [ %5778, %5585 ] + %5847 = phi float [ %7275, %__nv_exp2f.exit1417 ], [ %5650, %5585 ] + %5848 = phi float [ %7276, %__nv_exp2f.exit1417 ], [ %5651, %5585 ] + %5849 = phi float [ %7277, %__nv_exp2f.exit1417 ], [ %5652, %5585 ] + %5850 = phi float [ %7278, %__nv_exp2f.exit1417 ], [ %5653, %5585 ] + %5851 = phi float [ %7279, %__nv_exp2f.exit1417 ], [ %5654, %5585 ] + %5852 = phi float [ %7280, %__nv_exp2f.exit1417 ], [ %5655, %5585 ] + %5853 = phi float [ %7281, %__nv_exp2f.exit1417 ], [ %5656, %5585 ] + %5854 = phi float [ %7282, %__nv_exp2f.exit1417 ], [ %5657, %5585 ] + %5855 = phi float [ %7283, %__nv_exp2f.exit1417 ], [ %5658, %5585 ] + %5856 = phi float [ %7284, %__nv_exp2f.exit1417 ], [ %5659, %5585 ] + %5857 = phi float [ %7285, %__nv_exp2f.exit1417 ], [ %5660, %5585 ] + %5858 = phi float [ %7286, %__nv_exp2f.exit1417 ], [ %5661, %5585 ] + %5859 = phi float [ %7287, %__nv_exp2f.exit1417 ], [ %5662, %5585 ] + %5860 = phi float [ %7288, %__nv_exp2f.exit1417 ], [ %5663, %5585 ] + %5861 = phi float [ %7289, %__nv_exp2f.exit1417 ], [ %5664, %5585 ] + %5862 = phi float [ %7290, %__nv_exp2f.exit1417 ], [ %5665, %5585 ] + %5863 = phi float [ %7291, %__nv_exp2f.exit1417 ], [ %5666, %5585 ] + %5864 = phi float [ %7292, %__nv_exp2f.exit1417 ], [ %5667, %5585 ] + %5865 = phi float [ %7293, %__nv_exp2f.exit1417 ], [ %5668, %5585 ] + %5866 = phi float [ %7294, %__nv_exp2f.exit1417 ], [ %5669, %5585 ] + %5867 = phi float [ %7295, %__nv_exp2f.exit1417 ], [ %5670, %5585 ] + %5868 = phi float [ %7296, %__nv_exp2f.exit1417 ], [ %5671, %5585 ] + %5869 = phi float [ %7297, %__nv_exp2f.exit1417 ], [ %5672, %5585 ] + %5870 = phi float [ %7298, %__nv_exp2f.exit1417 ], [ %5673, %5585 ] + %5871 = phi float [ %7299, %__nv_exp2f.exit1417 ], [ %5674, %5585 ] + %5872 = phi float [ %7300, %__nv_exp2f.exit1417 ], [ %5675, %5585 ] + %5873 = phi float [ %7301, %__nv_exp2f.exit1417 ], [ %5676, %5585 ] + %5874 = phi float [ %7302, %__nv_exp2f.exit1417 ], [ %5677, %5585 ] + %5875 = phi float [ %7303, %__nv_exp2f.exit1417 ], [ %5678, %5585 ] + %5876 = phi float [ %7304, %__nv_exp2f.exit1417 ], [ %5679, %5585 ] + %5877 = phi float [ %7305, %__nv_exp2f.exit1417 ], [ %5680, %5585 ] + %5878 = phi float [ %7306, %__nv_exp2f.exit1417 ], [ %5681, %5585 ] + %5879 = phi float [ %7307, %__nv_exp2f.exit1417 ], [ %5682, %5585 ] + %5880 = phi float [ %7308, %__nv_exp2f.exit1417 ], [ %5683, %5585 ] + %5881 = phi float [ %7309, %__nv_exp2f.exit1417 ], [ %5684, %5585 ] + %5882 = phi float [ %7310, %__nv_exp2f.exit1417 ], [ %5685, %5585 ] + %5883 = phi float [ %7311, %__nv_exp2f.exit1417 ], [ %5686, %5585 ] + %5884 = phi float [ %7312, %__nv_exp2f.exit1417 ], [ %5687, %5585 ] + %5885 = phi float [ %7313, %__nv_exp2f.exit1417 ], [ %5688, %5585 ] + %5886 = phi float [ %7314, %__nv_exp2f.exit1417 ], [ %5689, %5585 ] + %5887 = phi float [ %7315, %__nv_exp2f.exit1417 ], [ %5690, %5585 ] + %5888 = phi float [ %7316, %__nv_exp2f.exit1417 ], [ %5691, %5585 ] + %5889 = phi float [ %7317, %__nv_exp2f.exit1417 ], [ %5692, %5585 ] + %5890 = phi float [ %7318, %__nv_exp2f.exit1417 ], [ %5693, %5585 ] + %5891 = phi float [ %7319, %__nv_exp2f.exit1417 ], [ %5694, %5585 ] + %5892 = phi float [ %7320, %__nv_exp2f.exit1417 ], [ %5695, %5585 ] + %5893 = phi float [ %7321, %__nv_exp2f.exit1417 ], [ %5696, %5585 ] + %5894 = phi float [ %7322, %__nv_exp2f.exit1417 ], [ %5697, %5585 ] + %5895 = phi float [ %7323, %__nv_exp2f.exit1417 ], [ %5698, %5585 ] + %5896 = phi float [ %7324, %__nv_exp2f.exit1417 ], [ %5699, %5585 ] + %5897 = phi float [ %7325, %__nv_exp2f.exit1417 ], [ %5700, %5585 ] + %5898 = phi float [ %7326, %__nv_exp2f.exit1417 ], [ %5701, %5585 ] + %5899 = phi float [ %7327, %__nv_exp2f.exit1417 ], [ %5702, %5585 ] + %5900 = phi float [ %7328, %__nv_exp2f.exit1417 ], [ %5703, %5585 ] + %5901 = phi float [ %7329, %__nv_exp2f.exit1417 ], [ %5704, %5585 ] + %5902 = phi float [ %7330, %__nv_exp2f.exit1417 ], [ %5705, %5585 ] + %5903 = phi float [ %7331, %__nv_exp2f.exit1417 ], [ %5706, %5585 ] + %5904 = phi float [ %7332, %__nv_exp2f.exit1417 ], [ %5707, %5585 ] + %5905 = phi float [ %7333, %__nv_exp2f.exit1417 ], [ %5708, %5585 ] + %5906 = phi float [ %7334, %__nv_exp2f.exit1417 ], [ %5709, %5585 ] + %5907 = phi float [ %7335, %__nv_exp2f.exit1417 ], [ %5710, %5585 ] + %5908 = phi float [ %7336, %__nv_exp2f.exit1417 ], [ %5711, %5585 ] + %5909 = phi float [ %7337, %__nv_exp2f.exit1417 ], [ %5712, %5585 ] + %5910 = phi float [ %7338, %__nv_exp2f.exit1417 ], [ %5713, %5585 ] + %5911 = phi float [ %8131, %__nv_exp2f.exit1417 ], [ %5586, %5585 ] + %5912 = phi float [ %8132, %__nv_exp2f.exit1417 ], [ %5587, %5585 ] + %5913 = phi float [ %8133, %__nv_exp2f.exit1417 ], [ %5588, %5585 ] + %5914 = phi float [ %8134, %__nv_exp2f.exit1417 ], [ %5589, %5585 ] + %5915 = phi float [ %8135, %__nv_exp2f.exit1417 ], [ %5590, %5585 ] + %5916 = phi float [ %8136, %__nv_exp2f.exit1417 ], [ %5591, %5585 ] + %5917 = phi float [ %8137, %__nv_exp2f.exit1417 ], [ %5592, %5585 ] + %5918 = phi float [ %8138, %__nv_exp2f.exit1417 ], [ %5593, %5585 ] + %5919 = phi float [ %8139, %__nv_exp2f.exit1417 ], [ %5594, %5585 ] + %5920 = phi float [ %8140, %__nv_exp2f.exit1417 ], [ %5595, %5585 ] + %5921 = phi float [ %8141, %__nv_exp2f.exit1417 ], [ %5596, %5585 ] + %5922 = phi float [ %8142, %__nv_exp2f.exit1417 ], [ %5597, %5585 ] + %5923 = phi float [ %8143, %__nv_exp2f.exit1417 ], [ %5598, %5585 ] + %5924 = phi float [ %8144, %__nv_exp2f.exit1417 ], [ %5599, %5585 ] + %5925 = phi float [ %8145, %__nv_exp2f.exit1417 ], [ %5600, %5585 ] + %5926 = phi float [ %8146, %__nv_exp2f.exit1417 ], [ %5601, %5585 ] + %5927 = phi float [ %8147, %__nv_exp2f.exit1417 ], [ %5602, %5585 ] + %5928 = phi float [ %8148, %__nv_exp2f.exit1417 ], [ %5603, %5585 ] + %5929 = phi float [ %8149, %__nv_exp2f.exit1417 ], [ %5604, %5585 ] + %5930 = phi float [ %8150, %__nv_exp2f.exit1417 ], [ %5605, %5585 ] + %5931 = phi float [ %8151, %__nv_exp2f.exit1417 ], [ %5606, %5585 ] + %5932 = phi float [ %8152, %__nv_exp2f.exit1417 ], [ %5607, %5585 ] + %5933 = phi float [ %8153, %__nv_exp2f.exit1417 ], [ %5608, %5585 ] + %5934 = phi float [ %8154, %__nv_exp2f.exit1417 ], [ %5609, %5585 ] + %5935 = phi float [ %8155, %__nv_exp2f.exit1417 ], [ %5610, %5585 ] + %5936 = phi float [ %8156, %__nv_exp2f.exit1417 ], [ %5611, %5585 ] + %5937 = phi float [ %8157, %__nv_exp2f.exit1417 ], [ %5612, %5585 ] + %5938 = phi float [ %8158, %__nv_exp2f.exit1417 ], [ %5613, %5585 ] + %5939 = phi float [ %8159, %__nv_exp2f.exit1417 ], [ %5614, %5585 ] + %5940 = phi float [ %8160, %__nv_exp2f.exit1417 ], [ %5615, %5585 ] + %5941 = phi float [ %8161, %__nv_exp2f.exit1417 ], [ %5616, %5585 ] + %5942 = phi float [ %8162, %__nv_exp2f.exit1417 ], [ %5617, %5585 ] + %5943 = phi float [ %8163, %__nv_exp2f.exit1417 ], [ %5618, %5585 ] + %5944 = phi float [ %8164, %__nv_exp2f.exit1417 ], [ %5619, %5585 ] + %5945 = phi float [ %8165, %__nv_exp2f.exit1417 ], [ %5620, %5585 ] + %5946 = phi float [ %8166, %__nv_exp2f.exit1417 ], [ %5621, %5585 ] + %5947 = phi float [ %8167, %__nv_exp2f.exit1417 ], [ %5622, %5585 ] + %5948 = phi float [ %8168, %__nv_exp2f.exit1417 ], [ %5623, %5585 ] + %5949 = phi float [ %8169, %__nv_exp2f.exit1417 ], [ %5624, %5585 ] + %5950 = phi float [ %8170, %__nv_exp2f.exit1417 ], [ %5625, %5585 ] + %5951 = phi float [ %8171, %__nv_exp2f.exit1417 ], [ %5626, %5585 ] + %5952 = phi float [ %8172, %__nv_exp2f.exit1417 ], [ %5627, %5585 ] + %5953 = phi float [ %8173, %__nv_exp2f.exit1417 ], [ %5628, %5585 ] + %5954 = phi float [ %8174, %__nv_exp2f.exit1417 ], [ %5629, %5585 ] + %5955 = phi float [ %8175, %__nv_exp2f.exit1417 ], [ %5630, %5585 ] + %5956 = phi float [ %8176, %__nv_exp2f.exit1417 ], [ %5631, %5585 ] + %5957 = phi float [ %8177, %__nv_exp2f.exit1417 ], [ %5632, %5585 ] + %5958 = phi float [ %8178, %__nv_exp2f.exit1417 ], [ %5633, %5585 ] + %5959 = phi float [ %8179, %__nv_exp2f.exit1417 ], [ %5634, %5585 ] + %5960 = phi float [ %8180, %__nv_exp2f.exit1417 ], [ %5635, %5585 ] + %5961 = phi float [ %8181, %__nv_exp2f.exit1417 ], [ %5636, %5585 ] + %5962 = phi float [ %8182, %__nv_exp2f.exit1417 ], [ %5637, %5585 ] + %5963 = phi float [ %8183, %__nv_exp2f.exit1417 ], [ %5638, %5585 ] + %5964 = phi float [ %8184, %__nv_exp2f.exit1417 ], [ %5639, %5585 ] + %5965 = phi float [ %8185, %__nv_exp2f.exit1417 ], [ %5640, %5585 ] + %5966 = phi float [ %8186, %__nv_exp2f.exit1417 ], [ %5641, %5585 ] + %5967 = phi float [ %8187, %__nv_exp2f.exit1417 ], [ %5642, %5585 ] + %5968 = phi float [ %8188, %__nv_exp2f.exit1417 ], [ %5643, %5585 ] + %5969 = phi float [ %8189, %__nv_exp2f.exit1417 ], [ %5644, %5585 ] + %5970 = phi float [ %8190, %__nv_exp2f.exit1417 ], [ %5645, %5585 ] + %5971 = phi float [ %8191, %__nv_exp2f.exit1417 ], [ %5646, %5585 ] + %5972 = phi float [ %8192, %__nv_exp2f.exit1417 ], [ %5647, %5585 ] + %5973 = phi float [ %8193, %__nv_exp2f.exit1417 ], [ %5648, %5585 ] + %5974 = phi float [ %8194, %__nv_exp2f.exit1417 ], [ %5649, %5585 ] + %5975 = phi i32 [ %8198, %__nv_exp2f.exit1417 ], [ 0, %5585 ] + %5976 = phi <16 x i32> [ %8197, %__nv_exp2f.exit1417 ], [ %5040, %5585 ] + %5977 = icmp slt i32 %5975, %5394, !dbg !276 + %5978 = icmp slt i32 %5975, %5395, !dbg !276 + %5979 = add i32 %5835, 1, !dbg !276 + %5980 = icmp sgt i32 %5979, 1, !dbg !276 + %5981 = select i1 %5980, i32 0, i32 %5979, !dbg !276 + %5982 = add i32 %5837, 1, !dbg !276 + %5983 = icmp sgt i32 %5982, 2, !dbg !276 + %5984 = select i1 %5983, i32 0, i32 %5982, !dbg !276 + %5985 = icmp slt i32 %5819, %17, !dbg !277 + %5986 = icmp slt i32 %5820, %17, !dbg !277 + %5987 = icmp slt i32 %5821, %17, !dbg !277 + %5988 = icmp slt i32 %5822, %17, !dbg !277 + %5989 = icmp slt i32 %5823, %17, !dbg !277 + %5990 = icmp slt i32 %5824, %17, !dbg !277 + %5991 = icmp slt i32 %5825, %17, !dbg !277 + %5992 = icmp slt i32 %5826, %17, !dbg !277 + %5993 = icmp slt i32 %5827, %17, !dbg !277 + %5994 = icmp slt i32 %5828, %17, !dbg !277 + %5995 = icmp slt i32 %5829, %17, !dbg !277 + %5996 = icmp slt i32 %5830, %17, !dbg !277 + %5997 = icmp slt i32 %5831, %17, !dbg !277 + %5998 = icmp slt i32 %5832, %17, !dbg !277 + %5999 = icmp slt i32 %5833, %17, !dbg !277 + %6000 = icmp slt i32 %5834, %17, !dbg !277 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !268 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !268 + %6001 = shl i32 %5984, 13, !dbg !268 + %6002 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %6001, !dbg !268 + %6003 = shl i32 %5981, 6, !dbg !270 + %6004 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %6003, !dbg !270 + %6005 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5189, !dbg !270 + %6006 = load float, ptr addrspace(3) %6005, align 8, !dbg !270 + %6007 = getelementptr inbounds nuw i8, ptr addrspace(3) %6005, i32 4, !dbg !270 + %6008 = load float, ptr addrspace(3) %6007, align 4, !dbg !270 + %6009 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5195, !dbg !270 + %6010 = load float, ptr addrspace(3) %6009, align 8, !dbg !270 + %6011 = getelementptr inbounds nuw i8, ptr addrspace(3) %6009, i32 4, !dbg !270 + %6012 = load float, ptr addrspace(3) %6011, align 4, !dbg !270 + %6013 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5201, !dbg !270 + %6014 = load float, ptr addrspace(3) %6013, align 8, !dbg !270 + %6015 = getelementptr inbounds nuw i8, ptr addrspace(3) %6013, i32 4, !dbg !270 + %6016 = load float, ptr addrspace(3) %6015, align 4, !dbg !270 + %6017 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5207, !dbg !270 + %6018 = load float, ptr addrspace(3) %6017, align 8, !dbg !270 + %6019 = getelementptr inbounds nuw i8, ptr addrspace(3) %6017, i32 4, !dbg !270 + %6020 = load float, ptr addrspace(3) %6019, align 4, !dbg !270 + %6021 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5213, !dbg !270 + %6022 = load float, ptr addrspace(3) %6021, align 8, !dbg !270 + %6023 = getelementptr inbounds nuw i8, ptr addrspace(3) %6021, i32 4, !dbg !270 + %6024 = load float, ptr addrspace(3) %6023, align 4, !dbg !270 + %6025 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5219, !dbg !270 + %6026 = load float, ptr addrspace(3) %6025, align 8, !dbg !270 + %6027 = getelementptr inbounds nuw i8, ptr addrspace(3) %6025, i32 4, !dbg !270 + %6028 = load float, ptr addrspace(3) %6027, align 4, !dbg !270 + %6029 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5225, !dbg !270 + %6030 = load float, ptr addrspace(3) %6029, align 8, !dbg !270 + %6031 = getelementptr inbounds nuw i8, ptr addrspace(3) %6029, i32 4, !dbg !270 + %6032 = load float, ptr addrspace(3) %6031, align 4, !dbg !270 + %6033 = getelementptr inbounds nuw i8, ptr addrspace(3) %6004, i32 %5231, !dbg !270 + %6034 = load float, ptr addrspace(3) %6033, align 8, !dbg !270 + %6035 = getelementptr inbounds nuw i8, ptr addrspace(3) %6033, i32 4, !dbg !270 + %6036 = load float, ptr addrspace(3) %6035, align 4, !dbg !270 + %6037 = fcmp oeq float %6006, 0xFFF0000000000000, !dbg !278 + %6038 = fcmp oeq float %6008, 0xFFF0000000000000, !dbg !278 + %6039 = fcmp oeq float %6010, 0xFFF0000000000000, !dbg !278 + %6040 = fcmp oeq float %6012, 0xFFF0000000000000, !dbg !278 + %6041 = fcmp oeq float %6014, 0xFFF0000000000000, !dbg !278 + %6042 = fcmp oeq float %6016, 0xFFF0000000000000, !dbg !278 + %6043 = fcmp oeq float %6018, 0xFFF0000000000000, !dbg !278 + %6044 = fcmp oeq float %6020, 0xFFF0000000000000, !dbg !278 + %6045 = fcmp oeq float %6022, 0xFFF0000000000000, !dbg !278 + %6046 = fcmp oeq float %6024, 0xFFF0000000000000, !dbg !278 + %6047 = fcmp oeq float %6026, 0xFFF0000000000000, !dbg !278 + %6048 = fcmp oeq float %6028, 0xFFF0000000000000, !dbg !278 + %6049 = fcmp oeq float %6030, 0xFFF0000000000000, !dbg !278 + %6050 = fcmp oeq float %6032, 0xFFF0000000000000, !dbg !278 + %6051 = fcmp oeq float %6034, 0xFFF0000000000000, !dbg !278 + %6052 = fcmp oeq float %6036, 0xFFF0000000000000, !dbg !278 + %6053 = select i1 %6037, float 0.000000e+00, float %6006, !dbg !279 + %6054 = select i1 %6038, float 0.000000e+00, float %6008, !dbg !279 + %6055 = select i1 %6039, float 0.000000e+00, float %6010, !dbg !279 + %6056 = select i1 %6040, float 0.000000e+00, float %6012, !dbg !279 + %6057 = select i1 %6041, float 0.000000e+00, float %6014, !dbg !279 + %6058 = select i1 %6042, float 0.000000e+00, float %6016, !dbg !279 + %6059 = select i1 %6043, float 0.000000e+00, float %6018, !dbg !279 + %6060 = select i1 %6044, float 0.000000e+00, float %6020, !dbg !279 + %6061 = select i1 %6045, float 0.000000e+00, float %6022, !dbg !279 + %6062 = select i1 %6046, float 0.000000e+00, float %6024, !dbg !279 + %6063 = select i1 %6047, float 0.000000e+00, float %6026, !dbg !279 + %6064 = select i1 %6048, float 0.000000e+00, float %6028, !dbg !279 + %6065 = select i1 %6049, float 0.000000e+00, float %6030, !dbg !279 + %6066 = select i1 %6050, float 0.000000e+00, float %6032, !dbg !279 + %6067 = select i1 %6051, float 0.000000e+00, float %6034, !dbg !279 + %6068 = select i1 %6052, float 0.000000e+00, float %6036, !dbg !279 + %6069 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %51, i32 0, i32 31), !dbg !250 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !250 + %6070 = shl i32 %6069, 11, !dbg !250 + %6071 = and i32 %6070, 8192, !dbg !250 + %6072 = add i32 %6071, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6073 = lshr exact i32 %6072, 4, !dbg !250 + %6074 = and i32 %6073, 16383, !dbg !250 + %6075 = zext nneg i32 %6074 to i64, !dbg !250 + %6076 = or disjoint i64 %6075, 4611686293372403712, !dbg !250 + %6077 = ptrtoint ptr addrspace(3) %6002 to i32, !dbg !250 + %6078 = lshr exact i32 %6077, 4, !dbg !250 + %6079 = and i32 %6078, 16383, !dbg !250 + %6080 = zext nneg i32 %6079 to i64, !dbg !250 + %6081 = or disjoint i64 %6080, 4611686293338849280, !dbg !250 + %6082 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %6076, i64 %6081) #3, !dbg !250 + %6083 = or disjoint i32 %6071, 32, !dbg !250 + %6084 = add i32 %6083, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6085 = lshr exact i32 %6084, 4, !dbg !250 + %6086 = and i32 %6085, 16383, !dbg !250 + %6087 = zext nneg i32 %6086 to i64, !dbg !250 + %6088 = or disjoint i64 %6087, 4611686293372403712, !dbg !250 + %6089 = add i32 %6077, 32, !dbg !250 + %6090 = lshr exact i32 %6089, 4, !dbg !250 + %6091 = and i32 %6090, 16383, !dbg !250 + %6092 = zext nneg i32 %6091 to i64, !dbg !250 + %6093 = or disjoint i64 %6092, 4611686293338849280, !dbg !250 + %6094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 0, !dbg !250 + %6095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 1, !dbg !250 + %6096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 2, !dbg !250 + %6097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 3, !dbg !250 + %6098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 4, !dbg !250 + %6099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 5, !dbg !250 + %6100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 6, !dbg !250 + %6101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 7, !dbg !250 + %6102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 8, !dbg !250 + %6103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 9, !dbg !250 + %6104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 10, !dbg !250 + %6105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 11, !dbg !250 + %6106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 12, !dbg !250 + %6107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 13, !dbg !250 + %6108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 14, !dbg !250 + %6109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 15, !dbg !250 + %6110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 16, !dbg !250 + %6111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 17, !dbg !250 + %6112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 18, !dbg !250 + %6113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 19, !dbg !250 + %6114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 20, !dbg !250 + %6115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 21, !dbg !250 + %6116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 22, !dbg !250 + %6117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 23, !dbg !250 + %6118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 24, !dbg !250 + %6119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 25, !dbg !250 + %6120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 26, !dbg !250 + %6121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 27, !dbg !250 + %6122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 28, !dbg !250 + %6123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 29, !dbg !250 + %6124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 30, !dbg !250 + %6125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6082, 31, !dbg !250 + %6126 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6094, float %6095, float %6096, float %6097, float %6098, float %6099, float %6100, float %6101, float %6102, float %6103, float %6104, float %6105, float %6106, float %6107, float %6108, float %6109, float %6110, float %6111, float %6112, float %6113, float %6114, float %6115, float %6116, float %6117, float %6118, float %6119, float %6120, float %6121, float %6122, float %6123, float %6124, float %6125, i64 %6088, i64 %6093, i1 true) #3, !dbg !250 + %6127 = or disjoint i32 %6071, 64, !dbg !250 + %6128 = add i32 %6127, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6129 = lshr exact i32 %6128, 4, !dbg !250 + %6130 = and i32 %6129, 16383, !dbg !250 + %6131 = zext nneg i32 %6130 to i64, !dbg !250 + %6132 = or disjoint i64 %6131, 4611686293372403712, !dbg !250 + %6133 = add i32 %6077, 64, !dbg !250 + %6134 = lshr exact i32 %6133, 4, !dbg !250 + %6135 = and i32 %6134, 16383, !dbg !250 + %6136 = zext nneg i32 %6135 to i64, !dbg !250 + %6137 = or disjoint i64 %6136, 4611686293338849280, !dbg !250 + %6138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 0, !dbg !250 + %6139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 1, !dbg !250 + %6140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 2, !dbg !250 + %6141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 3, !dbg !250 + %6142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 4, !dbg !250 + %6143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 5, !dbg !250 + %6144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 6, !dbg !250 + %6145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 7, !dbg !250 + %6146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 8, !dbg !250 + %6147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 9, !dbg !250 + %6148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 10, !dbg !250 + %6149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 11, !dbg !250 + %6150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 12, !dbg !250 + %6151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 13, !dbg !250 + %6152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 14, !dbg !250 + %6153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 15, !dbg !250 + %6154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 16, !dbg !250 + %6155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 17, !dbg !250 + %6156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 18, !dbg !250 + %6157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 19, !dbg !250 + %6158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 20, !dbg !250 + %6159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 21, !dbg !250 + %6160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 22, !dbg !250 + %6161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 23, !dbg !250 + %6162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 24, !dbg !250 + %6163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 25, !dbg !250 + %6164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 26, !dbg !250 + %6165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 27, !dbg !250 + %6166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 28, !dbg !250 + %6167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 29, !dbg !250 + %6168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 30, !dbg !250 + %6169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6126, 31, !dbg !250 + %6170 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6138, float %6139, float %6140, float %6141, float %6142, float %6143, float %6144, float %6145, float %6146, float %6147, float %6148, float %6149, float %6150, float %6151, float %6152, float %6153, float %6154, float %6155, float %6156, float %6157, float %6158, float %6159, float %6160, float %6161, float %6162, float %6163, float %6164, float %6165, float %6166, float %6167, float %6168, float %6169, i64 %6132, i64 %6137, i1 true) #3, !dbg !250 + %6171 = or disjoint i32 %6071, 96, !dbg !250 + %6172 = add i32 %6171, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6173 = lshr exact i32 %6172, 4, !dbg !250 + %6174 = and i32 %6173, 16383, !dbg !250 + %6175 = zext nneg i32 %6174 to i64, !dbg !250 + %6176 = or disjoint i64 %6175, 4611686293372403712, !dbg !250 + %6177 = add i32 %6077, 96, !dbg !250 + %6178 = lshr exact i32 %6177, 4, !dbg !250 + %6179 = and i32 %6178, 16383, !dbg !250 + %6180 = zext nneg i32 %6179 to i64, !dbg !250 + %6181 = or disjoint i64 %6180, 4611686293338849280, !dbg !250 + %6182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 0, !dbg !250 + %6183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 1, !dbg !250 + %6184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 2, !dbg !250 + %6185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 3, !dbg !250 + %6186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 4, !dbg !250 + %6187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 5, !dbg !250 + %6188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 6, !dbg !250 + %6189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 7, !dbg !250 + %6190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 8, !dbg !250 + %6191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 9, !dbg !250 + %6192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 10, !dbg !250 + %6193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 11, !dbg !250 + %6194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 12, !dbg !250 + %6195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 13, !dbg !250 + %6196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 14, !dbg !250 + %6197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 15, !dbg !250 + %6198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 16, !dbg !250 + %6199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 17, !dbg !250 + %6200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 18, !dbg !250 + %6201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 19, !dbg !250 + %6202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 20, !dbg !250 + %6203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 21, !dbg !250 + %6204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 22, !dbg !250 + %6205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 23, !dbg !250 + %6206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 24, !dbg !250 + %6207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 25, !dbg !250 + %6208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 26, !dbg !250 + %6209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 27, !dbg !250 + %6210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 28, !dbg !250 + %6211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 29, !dbg !250 + %6212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 30, !dbg !250 + %6213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6170, 31, !dbg !250 + %6214 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6182, float %6183, float %6184, float %6185, float %6186, float %6187, float %6188, float %6189, float %6190, float %6191, float %6192, float %6193, float %6194, float %6195, float %6196, float %6197, float %6198, float %6199, float %6200, float %6201, float %6202, float %6203, float %6204, float %6205, float %6206, float %6207, float %6208, float %6209, float %6210, float %6211, float %6212, float %6213, i64 %6176, i64 %6181, i1 true) #3, !dbg !250 + %6215 = or disjoint i32 %6071, 16384, !dbg !250 + %6216 = add i32 %6215, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6217 = lshr exact i32 %6216, 4, !dbg !250 + %6218 = and i32 %6217, 16383, !dbg !250 + %6219 = zext nneg i32 %6218 to i64, !dbg !250 + %6220 = or disjoint i64 %6219, 4611686293372403712, !dbg !250 + %6221 = add i32 %6077, 8192, !dbg !250 + %6222 = lshr exact i32 %6221, 4, !dbg !250 + %6223 = and i32 %6222, 16383, !dbg !250 + %6224 = zext nneg i32 %6223 to i64, !dbg !250 + %6225 = or disjoint i64 %6224, 4611686293338849280, !dbg !250 + %6226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 0, !dbg !250 + %6227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 1, !dbg !250 + %6228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 2, !dbg !250 + %6229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 3, !dbg !250 + %6230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 4, !dbg !250 + %6231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 5, !dbg !250 + %6232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 6, !dbg !250 + %6233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 7, !dbg !250 + %6234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 8, !dbg !250 + %6235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 9, !dbg !250 + %6236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 10, !dbg !250 + %6237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 11, !dbg !250 + %6238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 12, !dbg !250 + %6239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 13, !dbg !250 + %6240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 14, !dbg !250 + %6241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 15, !dbg !250 + %6242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 16, !dbg !250 + %6243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 17, !dbg !250 + %6244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 18, !dbg !250 + %6245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 19, !dbg !250 + %6246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 20, !dbg !250 + %6247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 21, !dbg !250 + %6248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 22, !dbg !250 + %6249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 23, !dbg !250 + %6250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 24, !dbg !250 + %6251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 25, !dbg !250 + %6252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 26, !dbg !250 + %6253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 27, !dbg !250 + %6254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 28, !dbg !250 + %6255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 29, !dbg !250 + %6256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 30, !dbg !250 + %6257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6214, 31, !dbg !250 + %6258 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6226, float %6227, float %6228, float %6229, float %6230, float %6231, float %6232, float %6233, float %6234, float %6235, float %6236, float %6237, float %6238, float %6239, float %6240, float %6241, float %6242, float %6243, float %6244, float %6245, float %6246, float %6247, float %6248, float %6249, float %6250, float %6251, float %6252, float %6253, float %6254, float %6255, float %6256, float %6257, i64 %6220, i64 %6225, i1 true) #3, !dbg !250 + %6259 = or disjoint i32 %6071, 16416, !dbg !250 + %6260 = add i32 %6259, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6261 = lshr exact i32 %6260, 4, !dbg !250 + %6262 = and i32 %6261, 16383, !dbg !250 + %6263 = zext nneg i32 %6262 to i64, !dbg !250 + %6264 = or disjoint i64 %6263, 4611686293372403712, !dbg !250 + %6265 = add i32 %6077, 8224, !dbg !250 + %6266 = lshr exact i32 %6265, 4, !dbg !250 + %6267 = and i32 %6266, 16383, !dbg !250 + %6268 = zext nneg i32 %6267 to i64, !dbg !250 + %6269 = or disjoint i64 %6268, 4611686293338849280, !dbg !250 + %6270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 0, !dbg !250 + %6271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 1, !dbg !250 + %6272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 2, !dbg !250 + %6273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 3, !dbg !250 + %6274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 4, !dbg !250 + %6275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 5, !dbg !250 + %6276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 6, !dbg !250 + %6277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 7, !dbg !250 + %6278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 8, !dbg !250 + %6279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 9, !dbg !250 + %6280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 10, !dbg !250 + %6281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 11, !dbg !250 + %6282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 12, !dbg !250 + %6283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 13, !dbg !250 + %6284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 14, !dbg !250 + %6285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 15, !dbg !250 + %6286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 16, !dbg !250 + %6287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 17, !dbg !250 + %6288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 18, !dbg !250 + %6289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 19, !dbg !250 + %6290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 20, !dbg !250 + %6291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 21, !dbg !250 + %6292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 22, !dbg !250 + %6293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 23, !dbg !250 + %6294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 24, !dbg !250 + %6295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 25, !dbg !250 + %6296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 26, !dbg !250 + %6297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 27, !dbg !250 + %6298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 28, !dbg !250 + %6299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 29, !dbg !250 + %6300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 30, !dbg !250 + %6301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6258, 31, !dbg !250 + %6302 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6270, float %6271, float %6272, float %6273, float %6274, float %6275, float %6276, float %6277, float %6278, float %6279, float %6280, float %6281, float %6282, float %6283, float %6284, float %6285, float %6286, float %6287, float %6288, float %6289, float %6290, float %6291, float %6292, float %6293, float %6294, float %6295, float %6296, float %6297, float %6298, float %6299, float %6300, float %6301, i64 %6264, i64 %6269, i1 true) #3, !dbg !250 + %6303 = or disjoint i32 %6071, 16448, !dbg !250 + %6304 = add i32 %6303, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6305 = lshr exact i32 %6304, 4, !dbg !250 + %6306 = and i32 %6305, 16383, !dbg !250 + %6307 = zext nneg i32 %6306 to i64, !dbg !250 + %6308 = or disjoint i64 %6307, 4611686293372403712, !dbg !250 + %6309 = add i32 %6077, 8256, !dbg !250 + %6310 = lshr exact i32 %6309, 4, !dbg !250 + %6311 = and i32 %6310, 16383, !dbg !250 + %6312 = zext nneg i32 %6311 to i64, !dbg !250 + %6313 = or disjoint i64 %6312, 4611686293338849280, !dbg !250 + %6314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 0, !dbg !250 + %6315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 1, !dbg !250 + %6316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 2, !dbg !250 + %6317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 3, !dbg !250 + %6318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 4, !dbg !250 + %6319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 5, !dbg !250 + %6320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 6, !dbg !250 + %6321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 7, !dbg !250 + %6322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 8, !dbg !250 + %6323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 9, !dbg !250 + %6324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 10, !dbg !250 + %6325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 11, !dbg !250 + %6326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 12, !dbg !250 + %6327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 13, !dbg !250 + %6328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 14, !dbg !250 + %6329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 15, !dbg !250 + %6330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 16, !dbg !250 + %6331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 17, !dbg !250 + %6332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 18, !dbg !250 + %6333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 19, !dbg !250 + %6334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 20, !dbg !250 + %6335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 21, !dbg !250 + %6336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 22, !dbg !250 + %6337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 23, !dbg !250 + %6338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 24, !dbg !250 + %6339 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 25, !dbg !250 + %6340 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 26, !dbg !250 + %6341 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 27, !dbg !250 + %6342 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 28, !dbg !250 + %6343 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 29, !dbg !250 + %6344 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 30, !dbg !250 + %6345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6302, 31, !dbg !250 + %6346 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6314, float %6315, float %6316, float %6317, float %6318, float %6319, float %6320, float %6321, float %6322, float %6323, float %6324, float %6325, float %6326, float %6327, float %6328, float %6329, float %6330, float %6331, float %6332, float %6333, float %6334, float %6335, float %6336, float %6337, float %6338, float %6339, float %6340, float %6341, float %6342, float %6343, float %6344, float %6345, i64 %6308, i64 %6313, i1 true) #3, !dbg !250 + %6347 = or disjoint i32 %6071, 16480, !dbg !250 + %6348 = add i32 %6347, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !250 + %6349 = lshr exact i32 %6348, 4, !dbg !250 + %6350 = and i32 %6349, 16383, !dbg !250 + %6351 = zext nneg i32 %6350 to i64, !dbg !250 + %6352 = or disjoint i64 %6351, 4611686293372403712, !dbg !250 + %6353 = add i32 %6077, 8288, !dbg !250 + %6354 = lshr exact i32 %6353, 4, !dbg !250 + %6355 = and i32 %6354, 16383, !dbg !250 + %6356 = zext nneg i32 %6355 to i64, !dbg !250 + %6357 = or disjoint i64 %6356, 4611686293338849280, !dbg !250 + %6358 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 0, !dbg !250 + %6359 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 1, !dbg !250 + %6360 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 2, !dbg !250 + %6361 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 3, !dbg !250 + %6362 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 4, !dbg !250 + %6363 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 5, !dbg !250 + %6364 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 6, !dbg !250 + %6365 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 7, !dbg !250 + %6366 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 8, !dbg !250 + %6367 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 9, !dbg !250 + %6368 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 10, !dbg !250 + %6369 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 11, !dbg !250 + %6370 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 12, !dbg !250 + %6371 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 13, !dbg !250 + %6372 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 14, !dbg !250 + %6373 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 15, !dbg !250 + %6374 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 16, !dbg !250 + %6375 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 17, !dbg !250 + %6376 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 18, !dbg !250 + %6377 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 19, !dbg !250 + %6378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 20, !dbg !250 + %6379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 21, !dbg !250 + %6380 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 22, !dbg !250 + %6381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 23, !dbg !250 + %6382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 24, !dbg !250 + %6383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 25, !dbg !250 + %6384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 26, !dbg !250 + %6385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 27, !dbg !250 + %6386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 28, !dbg !250 + %6387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 29, !dbg !250 + %6388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 30, !dbg !250 + %6389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6346, 31, !dbg !250 + %6390 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6358, float %6359, float %6360, float %6361, float %6362, float %6363, float %6364, float %6365, float %6366, float %6367, float %6368, float %6369, float %6370, float %6371, float %6372, float %6373, float %6374, float %6375, float %6376, float %6377, float %6378, float %6379, float %6380, float %6381, float %6382, float %6383, float %6384, float %6385, float %6386, float %6387, float %6388, float %6389, i64 %6352, i64 %6357, i1 true) #3, !dbg !250 + %6391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 0, !dbg !250 + %6392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 1, !dbg !250 + %6393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 2, !dbg !250 + %6394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 3, !dbg !250 + %6395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 4, !dbg !250 + %6396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 5, !dbg !250 + %6397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 6, !dbg !250 + %6398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 7, !dbg !250 + %6399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 8, !dbg !250 + %6400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 9, !dbg !250 + %6401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 10, !dbg !250 + %6402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 11, !dbg !250 + %6403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 12, !dbg !250 + %6404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 13, !dbg !250 + %6405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 14, !dbg !250 + %6406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 15, !dbg !250 + %6407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 16, !dbg !250 + %6408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 17, !dbg !250 + %6409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 18, !dbg !250 + %6410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 19, !dbg !250 + %6411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 20, !dbg !250 + %6412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 21, !dbg !250 + %6413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 22, !dbg !250 + %6414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 23, !dbg !250 + %6415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 24, !dbg !250 + %6416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 25, !dbg !250 + %6417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 26, !dbg !250 + %6418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 27, !dbg !250 + %6419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 28, !dbg !250 + %6420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 29, !dbg !250 + %6421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 30, !dbg !250 + %6422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6390, 31, !dbg !250 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !250 + %6423 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %6391, float %6392, float %6393, float %6394, float %6395, float %6396, float %6397, float %6398, float %6399, float %6400, float %6401, float %6402, float %6403, float %6404, float %6405, float %6406, float %6407, float %6408, float %6409, float %6410, float %6411, float %6412, float %6413, float %6414, float %6415, float %6416, float %6417, float %6418, float %6419, float %6420, float %6421, float %6422, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %6002, i32 0, i32 0) #3, !dbg !250 + %6424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 0, !dbg !250 + %6425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 1, !dbg !250 + %6426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 2, !dbg !250 + %6427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 3, !dbg !250 + %6428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 4, !dbg !250 + %6429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 5, !dbg !250 + %6430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 6, !dbg !250 + %6431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 7, !dbg !250 + %6432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 8, !dbg !250 + %6433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 9, !dbg !250 + %6434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 10, !dbg !250 + %6435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 11, !dbg !250 + %6436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 12, !dbg !250 + %6437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 13, !dbg !250 + %6438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 14, !dbg !250 + %6439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 15, !dbg !250 + %6440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 16, !dbg !250 + %6441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 17, !dbg !250 + %6442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 18, !dbg !250 + %6443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 19, !dbg !250 + %6444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 20, !dbg !250 + %6445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 21, !dbg !250 + %6446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 22, !dbg !250 + %6447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 23, !dbg !250 + %6448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 24, !dbg !250 + %6449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 25, !dbg !250 + %6450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 26, !dbg !250 + %6451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 27, !dbg !250 + %6452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 28, !dbg !250 + %6453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 29, !dbg !250 + %6454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 30, !dbg !250 + %6455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6423, 31, !dbg !250 + %6456 = fmul float %6424, 0x3FB6A09E60000000, !dbg !280 + %6457 = fmul float %6425, 0x3FB6A09E60000000, !dbg !280 + %6458 = fmul float %6426, 0x3FB6A09E60000000, !dbg !280 + %6459 = fmul float %6427, 0x3FB6A09E60000000, !dbg !280 + %6460 = fmul float %6428, 0x3FB6A09E60000000, !dbg !280 + %6461 = fmul float %6429, 0x3FB6A09E60000000, !dbg !280 + %6462 = fmul float %6430, 0x3FB6A09E60000000, !dbg !280 + %6463 = fmul float %6431, 0x3FB6A09E60000000, !dbg !280 + %6464 = fmul float %6432, 0x3FB6A09E60000000, !dbg !280 + %6465 = fmul float %6433, 0x3FB6A09E60000000, !dbg !280 + %6466 = fmul float %6434, 0x3FB6A09E60000000, !dbg !280 + %6467 = fmul float %6435, 0x3FB6A09E60000000, !dbg !280 + %6468 = fmul float %6436, 0x3FB6A09E60000000, !dbg !280 + %6469 = fmul float %6437, 0x3FB6A09E60000000, !dbg !280 + %6470 = fmul float %6438, 0x3FB6A09E60000000, !dbg !280 + %6471 = fmul float %6439, 0x3FB6A09E60000000, !dbg !280 + %6472 = fmul float %6440, 0x3FB6A09E60000000, !dbg !280 + %6473 = fmul float %6441, 0x3FB6A09E60000000, !dbg !280 + %6474 = fmul float %6442, 0x3FB6A09E60000000, !dbg !280 + %6475 = fmul float %6443, 0x3FB6A09E60000000, !dbg !280 + %6476 = fmul float %6444, 0x3FB6A09E60000000, !dbg !280 + %6477 = fmul float %6445, 0x3FB6A09E60000000, !dbg !280 + %6478 = fmul float %6446, 0x3FB6A09E60000000, !dbg !280 + %6479 = fmul float %6447, 0x3FB6A09E60000000, !dbg !280 + %6480 = fmul float %6448, 0x3FB6A09E60000000, !dbg !280 + %6481 = fmul float %6449, 0x3FB6A09E60000000, !dbg !280 + %6482 = fmul float %6450, 0x3FB6A09E60000000, !dbg !280 + %6483 = fmul float %6451, 0x3FB6A09E60000000, !dbg !280 + %6484 = fmul float %6452, 0x3FB6A09E60000000, !dbg !280 + %6485 = fmul float %6453, 0x3FB6A09E60000000, !dbg !280 + %6486 = fmul float %6454, 0x3FB6A09E60000000, !dbg !280 + %6487 = fmul float %6455, 0x3FB6A09E60000000, !dbg !280 + %6488 = srem <16 x i32> %5976, %5582, !dbg !237 + %6489 = icmp slt <16 x i32> %6488, zeroinitializer, !dbg !281 + %6490 = extractelement <16 x i32> %6488, i64 15, !dbg !282 + %6491 = icmp sle i32 %5583, %6490, !dbg !283 + %6492 = extractelement <16 x i32> %6488, i64 14, !dbg !282 + %6493 = icmp sle i32 %5583, %6492, !dbg !283 + %6494 = icmp sle i32 %5584, %6490, !dbg !283 + %6495 = icmp sle i32 %5584, %6492, !dbg !283 + %6496 = extractelement <16 x i32> %6488, i64 13, !dbg !282 + %6497 = icmp sle i32 %5583, %6496, !dbg !283 + %6498 = extractelement <16 x i32> %6488, i64 12, !dbg !282 + %6499 = icmp sle i32 %5583, %6498, !dbg !283 + %6500 = icmp sle i32 %5584, %6496, !dbg !283 + %6501 = icmp sle i32 %5584, %6498, !dbg !283 + %6502 = extractelement <16 x i32> %6488, i64 11, !dbg !282 + %6503 = icmp sle i32 %5583, %6502, !dbg !283 + %6504 = extractelement <16 x i32> %6488, i64 10, !dbg !282 + %6505 = icmp sle i32 %5583, %6504, !dbg !283 + %6506 = icmp sle i32 %5584, %6502, !dbg !283 + %6507 = icmp sle i32 %5584, %6504, !dbg !283 + %6508 = extractelement <16 x i32> %6488, i64 9, !dbg !282 + %6509 = icmp sle i32 %5583, %6508, !dbg !283 + %6510 = extractelement <16 x i32> %6488, i64 8, !dbg !282 + %6511 = icmp sle i32 %5583, %6510, !dbg !283 + %6512 = icmp sle i32 %5584, %6508, !dbg !283 + %6513 = icmp sle i32 %5584, %6510, !dbg !283 + %6514 = extractelement <16 x i32> %6488, i64 7, !dbg !282 + %6515 = icmp sle i32 %5583, %6514, !dbg !283 + %6516 = extractelement <16 x i32> %6488, i64 6, !dbg !282 + %6517 = icmp sle i32 %5583, %6516, !dbg !283 + %6518 = icmp sle i32 %5584, %6514, !dbg !283 + %6519 = icmp sle i32 %5584, %6516, !dbg !283 + %6520 = extractelement <16 x i32> %6488, i64 5, !dbg !282 + %6521 = icmp sle i32 %5583, %6520, !dbg !283 + %6522 = extractelement <16 x i32> %6488, i64 4, !dbg !282 + %6523 = icmp sle i32 %5583, %6522, !dbg !283 + %6524 = icmp sle i32 %5584, %6520, !dbg !283 + %6525 = icmp sle i32 %5584, %6522, !dbg !283 + %6526 = extractelement <16 x i32> %6488, i64 3, !dbg !282 + %6527 = icmp sle i32 %5583, %6526, !dbg !283 + %6528 = extractelement <16 x i32> %6488, i64 2, !dbg !282 + %6529 = icmp sle i32 %5583, %6528, !dbg !283 + %6530 = icmp sle i32 %5584, %6526, !dbg !283 + %6531 = icmp sle i32 %5584, %6528, !dbg !283 + %6532 = extractelement <16 x i32> %6488, i64 1, !dbg !282 + %6533 = icmp sle i32 %5583, %6532, !dbg !283 + %6534 = extractelement <16 x i32> %6488, i64 0, !dbg !282 + %6535 = icmp sle i32 %5583, %6534, !dbg !283 + %6536 = icmp sle i32 %5584, %6532, !dbg !283 + %6537 = icmp sle i32 %5584, %6534, !dbg !283 + %6538 = extractelement <16 x i1> %6489, i64 15, !dbg !284 + %6539 = and i1 %6538, %6491, !dbg !284 + %6540 = extractelement <16 x i1> %6489, i64 14, !dbg !284 + %6541 = and i1 %6540, %6493, !dbg !284 + %6542 = and i1 %6538, %6494, !dbg !284 + %6543 = and i1 %6540, %6495, !dbg !284 + %6544 = extractelement <16 x i1> %6489, i64 13, !dbg !284 + %6545 = and i1 %6544, %6497, !dbg !284 + %6546 = extractelement <16 x i1> %6489, i64 12, !dbg !284 + %6547 = and i1 %6546, %6499, !dbg !284 + %6548 = and i1 %6544, %6500, !dbg !284 + %6549 = and i1 %6546, %6501, !dbg !284 + %6550 = extractelement <16 x i1> %6489, i64 11, !dbg !284 + %6551 = and i1 %6550, %6503, !dbg !284 + %6552 = extractelement <16 x i1> %6489, i64 10, !dbg !284 + %6553 = and i1 %6552, %6505, !dbg !284 + %6554 = and i1 %6550, %6506, !dbg !284 + %6555 = and i1 %6552, %6507, !dbg !284 + %6556 = extractelement <16 x i1> %6489, i64 9, !dbg !284 + %6557 = and i1 %6556, %6509, !dbg !284 + %6558 = extractelement <16 x i1> %6489, i64 8, !dbg !284 + %6559 = and i1 %6558, %6511, !dbg !284 + %6560 = and i1 %6556, %6512, !dbg !284 + %6561 = and i1 %6558, %6513, !dbg !284 + %6562 = extractelement <16 x i1> %6489, i64 7, !dbg !284 + %6563 = and i1 %6562, %6515, !dbg !284 + %6564 = extractelement <16 x i1> %6489, i64 6, !dbg !284 + %6565 = and i1 %6564, %6517, !dbg !284 + %6566 = and i1 %6562, %6518, !dbg !284 + %6567 = and i1 %6564, %6519, !dbg !284 + %6568 = extractelement <16 x i1> %6489, i64 5, !dbg !284 + %6569 = and i1 %6568, %6521, !dbg !284 + %6570 = extractelement <16 x i1> %6489, i64 4, !dbg !284 + %6571 = and i1 %6570, %6523, !dbg !284 + %6572 = and i1 %6568, %6524, !dbg !284 + %6573 = and i1 %6570, %6525, !dbg !284 + %6574 = extractelement <16 x i1> %6489, i64 3, !dbg !284 + %6575 = and i1 %6574, %6527, !dbg !284 + %6576 = extractelement <16 x i1> %6489, i64 2, !dbg !284 + %6577 = and i1 %6576, %6529, !dbg !284 + %6578 = and i1 %6574, %6530, !dbg !284 + %6579 = and i1 %6576, %6531, !dbg !284 + %6580 = extractelement <16 x i1> %6489, i64 1, !dbg !284 + %6581 = and i1 %6580, %6533, !dbg !284 + %6582 = extractelement <16 x i1> %6489, i64 0, !dbg !284 + %6583 = and i1 %6582, %6535, !dbg !284 + %6584 = and i1 %6580, %6536, !dbg !284 + %6585 = and i1 %6582, %6537, !dbg !284 + %6586 = icmp sgt i32 %6490, -1, !dbg !282 + %6587 = icmp sgt i32 %6492, -1, !dbg !282 + %6588 = icmp sgt i32 %6496, -1, !dbg !282 + %6589 = icmp sgt i32 %6498, -1, !dbg !282 + %6590 = icmp sgt i32 %6502, -1, !dbg !282 + %6591 = icmp sgt i32 %6504, -1, !dbg !282 + %6592 = icmp sgt i32 %6508, -1, !dbg !282 + %6593 = icmp sgt i32 %6510, -1, !dbg !282 + %6594 = icmp sgt i32 %6514, -1, !dbg !282 + %6595 = icmp sgt i32 %6516, -1, !dbg !282 + %6596 = icmp sgt i32 %6520, -1, !dbg !282 + %6597 = icmp sgt i32 %6522, -1, !dbg !282 + %6598 = icmp sgt i32 %6526, -1, !dbg !282 + %6599 = icmp sgt i32 %6528, -1, !dbg !282 + %6600 = icmp sgt i32 %6532, -1, !dbg !282 + %6601 = icmp sgt i32 %6534, -1, !dbg !282 + %6602 = and <16 x i32> %6488, splat (i32 15), !dbg !285 + %6603 = icmp ne <16 x i32> %6602, zeroinitializer, !dbg !285 + %6604 = sdiv <16 x i32> %6488, splat (i32 16), !dbg !286 + %6605 = and <16 x i1> %6489, %6603, !dbg !287 + %6606 = sext <16 x i1> %6605 to <16 x i32>, !dbg !287 + %6607 = add nsw <16 x i32> %6604, %6606, !dbg !287 + %6608 = shufflevector <16 x i32> %6607, <16 x i32> poison, <32 x i32> , !dbg !287 + %6609 = icmp eq <32 x i32> %6608, %5045, !dbg !288 + %6610 = extractelement <32 x i1> %6609, i64 31, !dbg !289 + %6611 = and i1 %6586, %6610, !dbg !289 + %6612 = extractelement <32 x i1> %6609, i64 30, !dbg !289 + %6613 = and i1 %6587, %6612, !dbg !289 + %6614 = extractelement <32 x i1> %6609, i64 29, !dbg !289 + %6615 = and i1 %6586, %6614, !dbg !289 + %6616 = extractelement <32 x i1> %6609, i64 28, !dbg !289 + %6617 = and i1 %6587, %6616, !dbg !289 + %6618 = extractelement <32 x i1> %6609, i64 27, !dbg !289 + %6619 = and i1 %6588, %6618, !dbg !289 + %6620 = extractelement <32 x i1> %6609, i64 26, !dbg !289 + %6621 = and i1 %6589, %6620, !dbg !289 + %6622 = extractelement <32 x i1> %6609, i64 25, !dbg !289 + %6623 = and i1 %6588, %6622, !dbg !289 + %6624 = extractelement <32 x i1> %6609, i64 24, !dbg !289 + %6625 = and i1 %6589, %6624, !dbg !289 + %6626 = extractelement <32 x i1> %6609, i64 23, !dbg !289 + %6627 = and i1 %6590, %6626, !dbg !289 + %6628 = extractelement <32 x i1> %6609, i64 22, !dbg !289 + %6629 = and i1 %6591, %6628, !dbg !289 + %6630 = extractelement <32 x i1> %6609, i64 21, !dbg !289 + %6631 = and i1 %6590, %6630, !dbg !289 + %6632 = extractelement <32 x i1> %6609, i64 20, !dbg !289 + %6633 = and i1 %6591, %6632, !dbg !289 + %6634 = extractelement <32 x i1> %6609, i64 19, !dbg !289 + %6635 = and i1 %6592, %6634, !dbg !289 + %6636 = extractelement <32 x i1> %6609, i64 18, !dbg !289 + %6637 = and i1 %6593, %6636, !dbg !289 + %6638 = extractelement <32 x i1> %6609, i64 17, !dbg !289 + %6639 = and i1 %6592, %6638, !dbg !289 + %6640 = extractelement <32 x i1> %6609, i64 16, !dbg !289 + %6641 = and i1 %6593, %6640, !dbg !289 + %6642 = extractelement <32 x i1> %6609, i64 15, !dbg !289 + %6643 = and i1 %6594, %6642, !dbg !289 + %6644 = extractelement <32 x i1> %6609, i64 14, !dbg !289 + %6645 = and i1 %6595, %6644, !dbg !289 + %6646 = extractelement <32 x i1> %6609, i64 13, !dbg !289 + %6647 = and i1 %6594, %6646, !dbg !289 + %6648 = extractelement <32 x i1> %6609, i64 12, !dbg !289 + %6649 = and i1 %6595, %6648, !dbg !289 + %6650 = extractelement <32 x i1> %6609, i64 11, !dbg !289 + %6651 = and i1 %6596, %6650, !dbg !289 + %6652 = extractelement <32 x i1> %6609, i64 10, !dbg !289 + %6653 = and i1 %6597, %6652, !dbg !289 + %6654 = extractelement <32 x i1> %6609, i64 9, !dbg !289 + %6655 = and i1 %6596, %6654, !dbg !289 + %6656 = extractelement <32 x i1> %6609, i64 8, !dbg !289 + %6657 = and i1 %6597, %6656, !dbg !289 + %6658 = extractelement <32 x i1> %6609, i64 7, !dbg !289 + %6659 = and i1 %6598, %6658, !dbg !289 + %6660 = extractelement <32 x i1> %6609, i64 6, !dbg !289 + %6661 = and i1 %6599, %6660, !dbg !289 + %6662 = extractelement <32 x i1> %6609, i64 5, !dbg !289 + %6663 = and i1 %6598, %6662, !dbg !289 + %6664 = extractelement <32 x i1> %6609, i64 4, !dbg !289 + %6665 = and i1 %6599, %6664, !dbg !289 + %6666 = extractelement <32 x i1> %6609, i64 3, !dbg !289 + %6667 = and i1 %6600, %6666, !dbg !289 + %6668 = extractelement <32 x i1> %6609, i64 2, !dbg !289 + %6669 = and i1 %6601, %6668, !dbg !289 + %6670 = extractelement <32 x i1> %6609, i64 1, !dbg !289 + %6671 = and i1 %6600, %6670, !dbg !289 + %6672 = extractelement <32 x i1> %6609, i64 0, !dbg !289 + %6673 = and i1 %6601, %6672, !dbg !289 + %6674 = or i1 %6539, %6611, !dbg !290 + %6675 = or i1 %6541, %6613, !dbg !290 + %6676 = or i1 %6542, %6615, !dbg !290 + %6677 = or i1 %6543, %6617, !dbg !290 + %6678 = or i1 %6545, %6619, !dbg !290 + %6679 = or i1 %6547, %6621, !dbg !290 + %6680 = or i1 %6548, %6623, !dbg !290 + %6681 = or i1 %6549, %6625, !dbg !290 + %6682 = or i1 %6551, %6627, !dbg !290 + %6683 = or i1 %6553, %6629, !dbg !290 + %6684 = or i1 %6554, %6631, !dbg !290 + %6685 = or i1 %6555, %6633, !dbg !290 + %6686 = or i1 %6557, %6635, !dbg !290 + %6687 = or i1 %6559, %6637, !dbg !290 + %6688 = or i1 %6560, %6639, !dbg !290 + %6689 = or i1 %6561, %6641, !dbg !290 + %6690 = or i1 %6563, %6643, !dbg !290 + %6691 = or i1 %6565, %6645, !dbg !290 + %6692 = or i1 %6566, %6647, !dbg !290 + %6693 = or i1 %6567, %6649, !dbg !290 + %6694 = or i1 %6569, %6651, !dbg !290 + %6695 = or i1 %6571, %6653, !dbg !290 + %6696 = or i1 %6572, %6655, !dbg !290 + %6697 = or i1 %6573, %6657, !dbg !290 + %6698 = or i1 %6575, %6659, !dbg !290 + %6699 = or i1 %6577, %6661, !dbg !290 + %6700 = or i1 %6578, %6663, !dbg !290 + %6701 = or i1 %6579, %6665, !dbg !290 + %6702 = or i1 %6581, %6667, !dbg !290 + %6703 = or i1 %6583, %6669, !dbg !290 + %6704 = or i1 %6584, %6671, !dbg !290 + %6705 = or i1 %6585, %6673, !dbg !290 + %6706 = select i1 %6674, i1 %5985, i1 false, !dbg !291 + %6707 = select i1 %6675, i1 %5986, i1 false, !dbg !291 + %6708 = select i1 %6676, i1 %5985, i1 false, !dbg !291 + %6709 = select i1 %6677, i1 %5986, i1 false, !dbg !291 + %6710 = select i1 %6678, i1 %5987, i1 false, !dbg !291 + %6711 = select i1 %6679, i1 %5988, i1 false, !dbg !291 + %6712 = select i1 %6680, i1 %5987, i1 false, !dbg !291 + %6713 = select i1 %6681, i1 %5988, i1 false, !dbg !291 + %6714 = select i1 %6682, i1 %5989, i1 false, !dbg !291 + %6715 = select i1 %6683, i1 %5990, i1 false, !dbg !291 + %6716 = select i1 %6684, i1 %5989, i1 false, !dbg !291 + %6717 = select i1 %6685, i1 %5990, i1 false, !dbg !291 + %6718 = select i1 %6686, i1 %5991, i1 false, !dbg !291 + %6719 = select i1 %6687, i1 %5992, i1 false, !dbg !291 + %6720 = select i1 %6688, i1 %5991, i1 false, !dbg !291 + %6721 = select i1 %6689, i1 %5992, i1 false, !dbg !291 + %6722 = select i1 %6690, i1 %5993, i1 false, !dbg !291 + %6723 = select i1 %6691, i1 %5994, i1 false, !dbg !291 + %6724 = select i1 %6692, i1 %5993, i1 false, !dbg !291 + %6725 = select i1 %6693, i1 %5994, i1 false, !dbg !291 + %6726 = select i1 %6694, i1 %5995, i1 false, !dbg !291 + %6727 = select i1 %6695, i1 %5996, i1 false, !dbg !291 + %6728 = select i1 %6696, i1 %5995, i1 false, !dbg !291 + %6729 = select i1 %6697, i1 %5996, i1 false, !dbg !291 + %6730 = select i1 %6698, i1 %5997, i1 false, !dbg !291 + %6731 = select i1 %6699, i1 %5998, i1 false, !dbg !291 + %6732 = select i1 %6700, i1 %5997, i1 false, !dbg !291 + %6733 = select i1 %6701, i1 %5998, i1 false, !dbg !291 + %6734 = select i1 %6702, i1 %5999, i1 false, !dbg !291 + %6735 = select i1 %6703, i1 %6000, i1 false, !dbg !291 + %6736 = select i1 %6704, i1 %5999, i1 false, !dbg !291 + %6737 = select i1 %6705, i1 %6000, i1 false, !dbg !291 + %6738 = fmul float %6456, 0x3FF7154760000000, !dbg !292 + %6739 = select i1 %6706, float %6738, float 0xFFF0000000000000, !dbg !291 + %6740 = fmul float %6457, 0x3FF7154760000000, !dbg !292 + %6741 = select i1 %6707, float %6740, float 0xFFF0000000000000, !dbg !291 + %6742 = fmul float %6458, 0x3FF7154760000000, !dbg !292 + %6743 = select i1 %6708, float %6742, float 0xFFF0000000000000, !dbg !291 + %6744 = fmul float %6459, 0x3FF7154760000000, !dbg !292 + %6745 = select i1 %6709, float %6744, float 0xFFF0000000000000, !dbg !291 + %6746 = fmul float %6460, 0x3FF7154760000000, !dbg !292 + %6747 = select i1 %6710, float %6746, float 0xFFF0000000000000, !dbg !291 + %6748 = fmul float %6461, 0x3FF7154760000000, !dbg !292 + %6749 = select i1 %6711, float %6748, float 0xFFF0000000000000, !dbg !291 + %6750 = fmul float %6462, 0x3FF7154760000000, !dbg !292 + %6751 = select i1 %6712, float %6750, float 0xFFF0000000000000, !dbg !291 + %6752 = fmul float %6463, 0x3FF7154760000000, !dbg !292 + %6753 = select i1 %6713, float %6752, float 0xFFF0000000000000, !dbg !291 + %6754 = fmul float %6464, 0x3FF7154760000000, !dbg !292 + %6755 = select i1 %6714, float %6754, float 0xFFF0000000000000, !dbg !291 + %6756 = fmul float %6465, 0x3FF7154760000000, !dbg !292 + %6757 = select i1 %6715, float %6756, float 0xFFF0000000000000, !dbg !291 + %6758 = fmul float %6466, 0x3FF7154760000000, !dbg !292 + %6759 = select i1 %6716, float %6758, float 0xFFF0000000000000, !dbg !291 + %6760 = fmul float %6467, 0x3FF7154760000000, !dbg !292 + %6761 = select i1 %6717, float %6760, float 0xFFF0000000000000, !dbg !291 + %6762 = fmul float %6468, 0x3FF7154760000000, !dbg !292 + %6763 = select i1 %6718, float %6762, float 0xFFF0000000000000, !dbg !291 + %6764 = fmul float %6469, 0x3FF7154760000000, !dbg !292 + %6765 = select i1 %6719, float %6764, float 0xFFF0000000000000, !dbg !291 + %6766 = fmul float %6470, 0x3FF7154760000000, !dbg !292 + %6767 = select i1 %6720, float %6766, float 0xFFF0000000000000, !dbg !291 + %6768 = fmul float %6471, 0x3FF7154760000000, !dbg !292 + %6769 = select i1 %6721, float %6768, float 0xFFF0000000000000, !dbg !291 + %6770 = fmul float %6472, 0x3FF7154760000000, !dbg !292 + %6771 = select i1 %6722, float %6770, float 0xFFF0000000000000, !dbg !291 + %6772 = fmul float %6473, 0x3FF7154760000000, !dbg !292 + %6773 = select i1 %6723, float %6772, float 0xFFF0000000000000, !dbg !291 + %6774 = fmul float %6474, 0x3FF7154760000000, !dbg !292 + %6775 = select i1 %6724, float %6774, float 0xFFF0000000000000, !dbg !291 + %6776 = fmul float %6475, 0x3FF7154760000000, !dbg !292 + %6777 = select i1 %6725, float %6776, float 0xFFF0000000000000, !dbg !291 + %6778 = fmul float %6476, 0x3FF7154760000000, !dbg !292 + %6779 = select i1 %6726, float %6778, float 0xFFF0000000000000, !dbg !291 + %6780 = fmul float %6477, 0x3FF7154760000000, !dbg !292 + %6781 = select i1 %6727, float %6780, float 0xFFF0000000000000, !dbg !291 + %6782 = fmul float %6478, 0x3FF7154760000000, !dbg !292 + %6783 = select i1 %6728, float %6782, float 0xFFF0000000000000, !dbg !291 + %6784 = fmul float %6479, 0x3FF7154760000000, !dbg !292 + %6785 = select i1 %6729, float %6784, float 0xFFF0000000000000, !dbg !291 + %6786 = fmul float %6480, 0x3FF7154760000000, !dbg !292 + %6787 = select i1 %6730, float %6786, float 0xFFF0000000000000, !dbg !291 + %6788 = fmul float %6481, 0x3FF7154760000000, !dbg !292 + %6789 = select i1 %6731, float %6788, float 0xFFF0000000000000, !dbg !291 + %6790 = fmul float %6482, 0x3FF7154760000000, !dbg !292 + %6791 = select i1 %6732, float %6790, float 0xFFF0000000000000, !dbg !291 + %6792 = fmul float %6483, 0x3FF7154760000000, !dbg !292 + %6793 = select i1 %6733, float %6792, float 0xFFF0000000000000, !dbg !291 + %6794 = fmul float %6484, 0x3FF7154760000000, !dbg !292 + %6795 = select i1 %6734, float %6794, float 0xFFF0000000000000, !dbg !291 + %6796 = fmul float %6485, 0x3FF7154760000000, !dbg !292 + %6797 = select i1 %6735, float %6796, float 0xFFF0000000000000, !dbg !291 + %6798 = fmul float %6486, 0x3FF7154760000000, !dbg !292 + %6799 = select i1 %6736, float %6798, float 0xFFF0000000000000, !dbg !291 + %6800 = fmul float %6487, 0x3FF7154760000000, !dbg !292 + %6801 = select i1 %6737, float %6800, float 0xFFF0000000000000, !dbg !291 + %6802 = fsub float %6739, %6053, !dbg !293 + %6803 = fsub float %6741, %6054, !dbg !293 + %6804 = fsub float %6743, %6053, !dbg !293 + %6805 = fsub float %6745, %6054, !dbg !293 + %6806 = fsub float %6747, %6055, !dbg !293 + %6807 = fsub float %6749, %6056, !dbg !293 + %6808 = fsub float %6751, %6055, !dbg !293 + %6809 = fsub float %6753, %6056, !dbg !293 + %6810 = fsub float %6755, %6057, !dbg !293 + %6811 = fsub float %6757, %6058, !dbg !293 + %6812 = fsub float %6759, %6057, !dbg !293 + %6813 = fsub float %6761, %6058, !dbg !293 + %6814 = fsub float %6763, %6059, !dbg !293 + %6815 = fsub float %6765, %6060, !dbg !293 + %6816 = fsub float %6767, %6059, !dbg !293 + %6817 = fsub float %6769, %6060, !dbg !293 + %6818 = fsub float %6771, %6061, !dbg !293 + %6819 = fsub float %6773, %6062, !dbg !293 + %6820 = fsub float %6775, %6061, !dbg !293 + %6821 = fsub float %6777, %6062, !dbg !293 + %6822 = fsub float %6779, %6063, !dbg !293 + %6823 = fsub float %6781, %6064, !dbg !293 + %6824 = fsub float %6783, %6063, !dbg !293 + %6825 = fsub float %6785, %6064, !dbg !293 + %6826 = fsub float %6787, %6065, !dbg !293 + %6827 = fsub float %6789, %6066, !dbg !293 + %6828 = fsub float %6791, %6065, !dbg !293 + %6829 = fsub float %6793, %6066, !dbg !293 + %6830 = fsub float %6795, %6067, !dbg !293 + %6831 = fsub float %6797, %6068, !dbg !293 + %6832 = fsub float %6799, %6067, !dbg !293 + %6833 = fsub float %6801, %6068, !dbg !293 + %6834 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1322 = icmp eq i32 %6834, 0, !dbg !294 + br i1 %.not.i1322, label %6837, label %6835, !dbg !294 + +6835: ; preds = %.lr.ph1700 + %6836 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6802) #3, !dbg !294 + br label %__nv_exp2f.exit1324, !dbg !294 + +6837: ; preds = %.lr.ph1700 + %6838 = tail call float @llvm.nvvm.ex2.approx.f(float %6802) #3, !dbg !294 + br label %__nv_exp2f.exit1324, !dbg !294 + +__nv_exp2f.exit1324: ; preds = %6835, %6837 + %.0.i1323 = phi float [ %6836, %6835 ], [ %6838, %6837 ], !dbg !294 + %6839 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1325 = icmp eq i32 %6839, 0, !dbg !294 + br i1 %.not.i1325, label %6842, label %6840, !dbg !294 + +6840: ; preds = %__nv_exp2f.exit1324 + %6841 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6803) #3, !dbg !294 + br label %__nv_exp2f.exit1327, !dbg !294 + +6842: ; preds = %__nv_exp2f.exit1324 + %6843 = tail call float @llvm.nvvm.ex2.approx.f(float %6803) #3, !dbg !294 + br label %__nv_exp2f.exit1327, !dbg !294 + +__nv_exp2f.exit1327: ; preds = %6840, %6842 + %.0.i1326 = phi float [ %6841, %6840 ], [ %6843, %6842 ], !dbg !294 + %6844 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1328 = icmp eq i32 %6844, 0, !dbg !294 + br i1 %.not.i1328, label %6847, label %6845, !dbg !294 + +6845: ; preds = %__nv_exp2f.exit1327 + %6846 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6804) #3, !dbg !294 + br label %__nv_exp2f.exit1330, !dbg !294 + +6847: ; preds = %__nv_exp2f.exit1327 + %6848 = tail call float @llvm.nvvm.ex2.approx.f(float %6804) #3, !dbg !294 + br label %__nv_exp2f.exit1330, !dbg !294 + +__nv_exp2f.exit1330: ; preds = %6845, %6847 + %.0.i1329 = phi float [ %6846, %6845 ], [ %6848, %6847 ], !dbg !294 + %6849 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1331 = icmp eq i32 %6849, 0, !dbg !294 + br i1 %.not.i1331, label %6852, label %6850, !dbg !294 + +6850: ; preds = %__nv_exp2f.exit1330 + %6851 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6805) #3, !dbg !294 + br label %__nv_exp2f.exit1333, !dbg !294 + +6852: ; preds = %__nv_exp2f.exit1330 + %6853 = tail call float @llvm.nvvm.ex2.approx.f(float %6805) #3, !dbg !294 + br label %__nv_exp2f.exit1333, !dbg !294 + +__nv_exp2f.exit1333: ; preds = %6850, %6852 + %.0.i1332 = phi float [ %6851, %6850 ], [ %6853, %6852 ], !dbg !294 + %6854 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1334 = icmp eq i32 %6854, 0, !dbg !294 + br i1 %.not.i1334, label %6857, label %6855, !dbg !294 + +6855: ; preds = %__nv_exp2f.exit1333 + %6856 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6806) #3, !dbg !294 + br label %__nv_exp2f.exit1336, !dbg !294 + +6857: ; preds = %__nv_exp2f.exit1333 + %6858 = tail call float @llvm.nvvm.ex2.approx.f(float %6806) #3, !dbg !294 + br label %__nv_exp2f.exit1336, !dbg !294 + +__nv_exp2f.exit1336: ; preds = %6855, %6857 + %.0.i1335 = phi float [ %6856, %6855 ], [ %6858, %6857 ], !dbg !294 + %6859 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1337 = icmp eq i32 %6859, 0, !dbg !294 + br i1 %.not.i1337, label %6862, label %6860, !dbg !294 + +6860: ; preds = %__nv_exp2f.exit1336 + %6861 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6807) #3, !dbg !294 + br label %__nv_exp2f.exit1339, !dbg !294 + +6862: ; preds = %__nv_exp2f.exit1336 + %6863 = tail call float @llvm.nvvm.ex2.approx.f(float %6807) #3, !dbg !294 + br label %__nv_exp2f.exit1339, !dbg !294 + +__nv_exp2f.exit1339: ; preds = %6860, %6862 + %.0.i1338 = phi float [ %6861, %6860 ], [ %6863, %6862 ], !dbg !294 + %6864 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1340 = icmp eq i32 %6864, 0, !dbg !294 + br i1 %.not.i1340, label %6867, label %6865, !dbg !294 + +6865: ; preds = %__nv_exp2f.exit1339 + %6866 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6808) #3, !dbg !294 + br label %__nv_exp2f.exit1342, !dbg !294 + +6867: ; preds = %__nv_exp2f.exit1339 + %6868 = tail call float @llvm.nvvm.ex2.approx.f(float %6808) #3, !dbg !294 + br label %__nv_exp2f.exit1342, !dbg !294 + +__nv_exp2f.exit1342: ; preds = %6865, %6867 + %.0.i1341 = phi float [ %6866, %6865 ], [ %6868, %6867 ], !dbg !294 + %6869 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1343 = icmp eq i32 %6869, 0, !dbg !294 + br i1 %.not.i1343, label %6872, label %6870, !dbg !294 + +6870: ; preds = %__nv_exp2f.exit1342 + %6871 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6809) #3, !dbg !294 + br label %__nv_exp2f.exit1345, !dbg !294 + +6872: ; preds = %__nv_exp2f.exit1342 + %6873 = tail call float @llvm.nvvm.ex2.approx.f(float %6809) #3, !dbg !294 + br label %__nv_exp2f.exit1345, !dbg !294 + +__nv_exp2f.exit1345: ; preds = %6870, %6872 + %.0.i1344 = phi float [ %6871, %6870 ], [ %6873, %6872 ], !dbg !294 + %6874 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1346 = icmp eq i32 %6874, 0, !dbg !294 + br i1 %.not.i1346, label %6877, label %6875, !dbg !294 + +6875: ; preds = %__nv_exp2f.exit1345 + %6876 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6810) #3, !dbg !294 + br label %__nv_exp2f.exit1348, !dbg !294 + +6877: ; preds = %__nv_exp2f.exit1345 + %6878 = tail call float @llvm.nvvm.ex2.approx.f(float %6810) #3, !dbg !294 + br label %__nv_exp2f.exit1348, !dbg !294 + +__nv_exp2f.exit1348: ; preds = %6875, %6877 + %.0.i1347 = phi float [ %6876, %6875 ], [ %6878, %6877 ], !dbg !294 + %6879 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1349 = icmp eq i32 %6879, 0, !dbg !294 + br i1 %.not.i1349, label %6882, label %6880, !dbg !294 + +6880: ; preds = %__nv_exp2f.exit1348 + %6881 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6811) #3, !dbg !294 + br label %__nv_exp2f.exit1351, !dbg !294 + +6882: ; preds = %__nv_exp2f.exit1348 + %6883 = tail call float @llvm.nvvm.ex2.approx.f(float %6811) #3, !dbg !294 + br label %__nv_exp2f.exit1351, !dbg !294 + +__nv_exp2f.exit1351: ; preds = %6880, %6882 + %.0.i1350 = phi float [ %6881, %6880 ], [ %6883, %6882 ], !dbg !294 + %6884 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1352 = icmp eq i32 %6884, 0, !dbg !294 + br i1 %.not.i1352, label %6887, label %6885, !dbg !294 + +6885: ; preds = %__nv_exp2f.exit1351 + %6886 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6812) #3, !dbg !294 + br label %__nv_exp2f.exit1354, !dbg !294 + +6887: ; preds = %__nv_exp2f.exit1351 + %6888 = tail call float @llvm.nvvm.ex2.approx.f(float %6812) #3, !dbg !294 + br label %__nv_exp2f.exit1354, !dbg !294 + +__nv_exp2f.exit1354: ; preds = %6885, %6887 + %.0.i1353 = phi float [ %6886, %6885 ], [ %6888, %6887 ], !dbg !294 + %6889 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1355 = icmp eq i32 %6889, 0, !dbg !294 + br i1 %.not.i1355, label %6892, label %6890, !dbg !294 + +6890: ; preds = %__nv_exp2f.exit1354 + %6891 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6813) #3, !dbg !294 + br label %__nv_exp2f.exit1357, !dbg !294 + +6892: ; preds = %__nv_exp2f.exit1354 + %6893 = tail call float @llvm.nvvm.ex2.approx.f(float %6813) #3, !dbg !294 + br label %__nv_exp2f.exit1357, !dbg !294 + +__nv_exp2f.exit1357: ; preds = %6890, %6892 + %.0.i1356 = phi float [ %6891, %6890 ], [ %6893, %6892 ], !dbg !294 + %6894 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1358 = icmp eq i32 %6894, 0, !dbg !294 + br i1 %.not.i1358, label %6897, label %6895, !dbg !294 + +6895: ; preds = %__nv_exp2f.exit1357 + %6896 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6814) #3, !dbg !294 + br label %__nv_exp2f.exit1360, !dbg !294 + +6897: ; preds = %__nv_exp2f.exit1357 + %6898 = tail call float @llvm.nvvm.ex2.approx.f(float %6814) #3, !dbg !294 + br label %__nv_exp2f.exit1360, !dbg !294 + +__nv_exp2f.exit1360: ; preds = %6895, %6897 + %.0.i1359 = phi float [ %6896, %6895 ], [ %6898, %6897 ], !dbg !294 + %6899 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1361 = icmp eq i32 %6899, 0, !dbg !294 + br i1 %.not.i1361, label %6902, label %6900, !dbg !294 + +6900: ; preds = %__nv_exp2f.exit1360 + %6901 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6815) #3, !dbg !294 + br label %__nv_exp2f.exit1363, !dbg !294 + +6902: ; preds = %__nv_exp2f.exit1360 + %6903 = tail call float @llvm.nvvm.ex2.approx.f(float %6815) #3, !dbg !294 + br label %__nv_exp2f.exit1363, !dbg !294 + +__nv_exp2f.exit1363: ; preds = %6900, %6902 + %.0.i1362 = phi float [ %6901, %6900 ], [ %6903, %6902 ], !dbg !294 + %6904 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1364 = icmp eq i32 %6904, 0, !dbg !294 + br i1 %.not.i1364, label %6907, label %6905, !dbg !294 + +6905: ; preds = %__nv_exp2f.exit1363 + %6906 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6816) #3, !dbg !294 + br label %__nv_exp2f.exit1366, !dbg !294 + +6907: ; preds = %__nv_exp2f.exit1363 + %6908 = tail call float @llvm.nvvm.ex2.approx.f(float %6816) #3, !dbg !294 + br label %__nv_exp2f.exit1366, !dbg !294 + +__nv_exp2f.exit1366: ; preds = %6905, %6907 + %.0.i1365 = phi float [ %6906, %6905 ], [ %6908, %6907 ], !dbg !294 + %6909 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1367 = icmp eq i32 %6909, 0, !dbg !294 + br i1 %.not.i1367, label %6912, label %6910, !dbg !294 + +6910: ; preds = %__nv_exp2f.exit1366 + %6911 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6817) #3, !dbg !294 + br label %__nv_exp2f.exit1369, !dbg !294 + +6912: ; preds = %__nv_exp2f.exit1366 + %6913 = tail call float @llvm.nvvm.ex2.approx.f(float %6817) #3, !dbg !294 + br label %__nv_exp2f.exit1369, !dbg !294 + +__nv_exp2f.exit1369: ; preds = %6910, %6912 + %.0.i1368 = phi float [ %6911, %6910 ], [ %6913, %6912 ], !dbg !294 + %6914 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1370 = icmp eq i32 %6914, 0, !dbg !294 + br i1 %.not.i1370, label %6917, label %6915, !dbg !294 + +6915: ; preds = %__nv_exp2f.exit1369 + %6916 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6818) #3, !dbg !294 + br label %__nv_exp2f.exit1372, !dbg !294 + +6917: ; preds = %__nv_exp2f.exit1369 + %6918 = tail call float @llvm.nvvm.ex2.approx.f(float %6818) #3, !dbg !294 + br label %__nv_exp2f.exit1372, !dbg !294 + +__nv_exp2f.exit1372: ; preds = %6915, %6917 + %.0.i1371 = phi float [ %6916, %6915 ], [ %6918, %6917 ], !dbg !294 + %6919 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1373 = icmp eq i32 %6919, 0, !dbg !294 + br i1 %.not.i1373, label %6922, label %6920, !dbg !294 + +6920: ; preds = %__nv_exp2f.exit1372 + %6921 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6819) #3, !dbg !294 + br label %__nv_exp2f.exit1375, !dbg !294 + +6922: ; preds = %__nv_exp2f.exit1372 + %6923 = tail call float @llvm.nvvm.ex2.approx.f(float %6819) #3, !dbg !294 + br label %__nv_exp2f.exit1375, !dbg !294 + +__nv_exp2f.exit1375: ; preds = %6920, %6922 + %.0.i1374 = phi float [ %6921, %6920 ], [ %6923, %6922 ], !dbg !294 + %6924 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1376 = icmp eq i32 %6924, 0, !dbg !294 + br i1 %.not.i1376, label %6927, label %6925, !dbg !294 + +6925: ; preds = %__nv_exp2f.exit1375 + %6926 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6820) #3, !dbg !294 + br label %__nv_exp2f.exit1378, !dbg !294 + +6927: ; preds = %__nv_exp2f.exit1375 + %6928 = tail call float @llvm.nvvm.ex2.approx.f(float %6820) #3, !dbg !294 + br label %__nv_exp2f.exit1378, !dbg !294 + +__nv_exp2f.exit1378: ; preds = %6925, %6927 + %.0.i1377 = phi float [ %6926, %6925 ], [ %6928, %6927 ], !dbg !294 + %6929 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1379 = icmp eq i32 %6929, 0, !dbg !294 + br i1 %.not.i1379, label %6932, label %6930, !dbg !294 + +6930: ; preds = %__nv_exp2f.exit1378 + %6931 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6821) #3, !dbg !294 + br label %__nv_exp2f.exit1381, !dbg !294 + +6932: ; preds = %__nv_exp2f.exit1378 + %6933 = tail call float @llvm.nvvm.ex2.approx.f(float %6821) #3, !dbg !294 + br label %__nv_exp2f.exit1381, !dbg !294 + +__nv_exp2f.exit1381: ; preds = %6930, %6932 + %.0.i1380 = phi float [ %6931, %6930 ], [ %6933, %6932 ], !dbg !294 + %6934 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1382 = icmp eq i32 %6934, 0, !dbg !294 + br i1 %.not.i1382, label %6937, label %6935, !dbg !294 + +6935: ; preds = %__nv_exp2f.exit1381 + %6936 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6822) #3, !dbg !294 + br label %__nv_exp2f.exit1384, !dbg !294 + +6937: ; preds = %__nv_exp2f.exit1381 + %6938 = tail call float @llvm.nvvm.ex2.approx.f(float %6822) #3, !dbg !294 + br label %__nv_exp2f.exit1384, !dbg !294 + +__nv_exp2f.exit1384: ; preds = %6935, %6937 + %.0.i1383 = phi float [ %6936, %6935 ], [ %6938, %6937 ], !dbg !294 + %6939 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1385 = icmp eq i32 %6939, 0, !dbg !294 + br i1 %.not.i1385, label %6942, label %6940, !dbg !294 + +6940: ; preds = %__nv_exp2f.exit1384 + %6941 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6823) #3, !dbg !294 + br label %__nv_exp2f.exit1387, !dbg !294 + +6942: ; preds = %__nv_exp2f.exit1384 + %6943 = tail call float @llvm.nvvm.ex2.approx.f(float %6823) #3, !dbg !294 + br label %__nv_exp2f.exit1387, !dbg !294 + +__nv_exp2f.exit1387: ; preds = %6940, %6942 + %.0.i1386 = phi float [ %6941, %6940 ], [ %6943, %6942 ], !dbg !294 + %6944 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1388 = icmp eq i32 %6944, 0, !dbg !294 + br i1 %.not.i1388, label %6947, label %6945, !dbg !294 + +6945: ; preds = %__nv_exp2f.exit1387 + %6946 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6824) #3, !dbg !294 + br label %__nv_exp2f.exit1390, !dbg !294 + +6947: ; preds = %__nv_exp2f.exit1387 + %6948 = tail call float @llvm.nvvm.ex2.approx.f(float %6824) #3, !dbg !294 + br label %__nv_exp2f.exit1390, !dbg !294 + +__nv_exp2f.exit1390: ; preds = %6945, %6947 + %.0.i1389 = phi float [ %6946, %6945 ], [ %6948, %6947 ], !dbg !294 + %6949 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1391 = icmp eq i32 %6949, 0, !dbg !294 + br i1 %.not.i1391, label %6952, label %6950, !dbg !294 + +6950: ; preds = %__nv_exp2f.exit1390 + %6951 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6825) #3, !dbg !294 + br label %__nv_exp2f.exit1393, !dbg !294 + +6952: ; preds = %__nv_exp2f.exit1390 + %6953 = tail call float @llvm.nvvm.ex2.approx.f(float %6825) #3, !dbg !294 + br label %__nv_exp2f.exit1393, !dbg !294 + +__nv_exp2f.exit1393: ; preds = %6950, %6952 + %.0.i1392 = phi float [ %6951, %6950 ], [ %6953, %6952 ], !dbg !294 + %6954 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1394 = icmp eq i32 %6954, 0, !dbg !294 + br i1 %.not.i1394, label %6957, label %6955, !dbg !294 + +6955: ; preds = %__nv_exp2f.exit1393 + %6956 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6826) #3, !dbg !294 + br label %__nv_exp2f.exit1396, !dbg !294 + +6957: ; preds = %__nv_exp2f.exit1393 + %6958 = tail call float @llvm.nvvm.ex2.approx.f(float %6826) #3, !dbg !294 + br label %__nv_exp2f.exit1396, !dbg !294 + +__nv_exp2f.exit1396: ; preds = %6955, %6957 + %.0.i1395 = phi float [ %6956, %6955 ], [ %6958, %6957 ], !dbg !294 + %6959 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1397 = icmp eq i32 %6959, 0, !dbg !294 + br i1 %.not.i1397, label %6962, label %6960, !dbg !294 + +6960: ; preds = %__nv_exp2f.exit1396 + %6961 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6827) #3, !dbg !294 + br label %__nv_exp2f.exit1399, !dbg !294 + +6962: ; preds = %__nv_exp2f.exit1396 + %6963 = tail call float @llvm.nvvm.ex2.approx.f(float %6827) #3, !dbg !294 + br label %__nv_exp2f.exit1399, !dbg !294 + +__nv_exp2f.exit1399: ; preds = %6960, %6962 + %.0.i1398 = phi float [ %6961, %6960 ], [ %6963, %6962 ], !dbg !294 + %6964 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1400 = icmp eq i32 %6964, 0, !dbg !294 + br i1 %.not.i1400, label %6967, label %6965, !dbg !294 + +6965: ; preds = %__nv_exp2f.exit1399 + %6966 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6828) #3, !dbg !294 + br label %__nv_exp2f.exit1402, !dbg !294 + +6967: ; preds = %__nv_exp2f.exit1399 + %6968 = tail call float @llvm.nvvm.ex2.approx.f(float %6828) #3, !dbg !294 + br label %__nv_exp2f.exit1402, !dbg !294 + +__nv_exp2f.exit1402: ; preds = %6965, %6967 + %.0.i1401 = phi float [ %6966, %6965 ], [ %6968, %6967 ], !dbg !294 + %6969 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1403 = icmp eq i32 %6969, 0, !dbg !294 + br i1 %.not.i1403, label %6972, label %6970, !dbg !294 + +6970: ; preds = %__nv_exp2f.exit1402 + %6971 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6829) #3, !dbg !294 + br label %__nv_exp2f.exit1405, !dbg !294 + +6972: ; preds = %__nv_exp2f.exit1402 + %6973 = tail call float @llvm.nvvm.ex2.approx.f(float %6829) #3, !dbg !294 + br label %__nv_exp2f.exit1405, !dbg !294 + +__nv_exp2f.exit1405: ; preds = %6970, %6972 + %.0.i1404 = phi float [ %6971, %6970 ], [ %6973, %6972 ], !dbg !294 + %6974 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1406 = icmp eq i32 %6974, 0, !dbg !294 + br i1 %.not.i1406, label %6977, label %6975, !dbg !294 + +6975: ; preds = %__nv_exp2f.exit1405 + %6976 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6830) #3, !dbg !294 + br label %__nv_exp2f.exit1408, !dbg !294 + +6977: ; preds = %__nv_exp2f.exit1405 + %6978 = tail call float @llvm.nvvm.ex2.approx.f(float %6830) #3, !dbg !294 + br label %__nv_exp2f.exit1408, !dbg !294 + +__nv_exp2f.exit1408: ; preds = %6975, %6977 + %.0.i1407 = phi float [ %6976, %6975 ], [ %6978, %6977 ], !dbg !294 + %6979 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1409 = icmp eq i32 %6979, 0, !dbg !294 + br i1 %.not.i1409, label %6982, label %6980, !dbg !294 + +6980: ; preds = %__nv_exp2f.exit1408 + %6981 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6831) #3, !dbg !294 + br label %__nv_exp2f.exit1411, !dbg !294 + +6982: ; preds = %__nv_exp2f.exit1408 + %6983 = tail call float @llvm.nvvm.ex2.approx.f(float %6831) #3, !dbg !294 + br label %__nv_exp2f.exit1411, !dbg !294 + +__nv_exp2f.exit1411: ; preds = %6980, %6982 + %.0.i1410 = phi float [ %6981, %6980 ], [ %6983, %6982 ], !dbg !294 + %6984 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1412 = icmp eq i32 %6984, 0, !dbg !294 + br i1 %.not.i1412, label %6987, label %6985, !dbg !294 + +6985: ; preds = %__nv_exp2f.exit1411 + %6986 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6832) #3, !dbg !294 + br label %__nv_exp2f.exit1414, !dbg !294 + +6987: ; preds = %__nv_exp2f.exit1411 + %6988 = tail call float @llvm.nvvm.ex2.approx.f(float %6832) #3, !dbg !294 + br label %__nv_exp2f.exit1414, !dbg !294 + +__nv_exp2f.exit1414: ; preds = %6985, %6987 + %.0.i1413 = phi float [ %6986, %6985 ], [ %6988, %6987 ], !dbg !294 + %6989 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !294 + %.not.i1415 = icmp eq i32 %6989, 0, !dbg !294 + br i1 %.not.i1415, label %6992, label %6990, !dbg !294 + +6990: ; preds = %__nv_exp2f.exit1414 + %6991 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6833) #3, !dbg !294 + br label %__nv_exp2f.exit1417, !dbg !294 + +6992: ; preds = %__nv_exp2f.exit1414 + %6993 = tail call float @llvm.nvvm.ex2.approx.f(float %6833) #3, !dbg !294 + br label %__nv_exp2f.exit1417, !dbg !294 + +__nv_exp2f.exit1417: ; preds = %6990, %6992 + %.0.i1416 = phi float [ %6991, %6990 ], [ %6993, %6992 ], !dbg !294 + %6994 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %6001, !dbg !271 + %6995 = insertelement <2 x float> poison, float %.0.i1323, i64 0, !dbg !295 + %6996 = insertelement <2 x float> %6995, float %.0.i1326, i64 1, !dbg !295 + %6997 = fptrunc <2 x float> %6996 to <2 x bfloat>, !dbg !295 + %6998 = insertelement <2 x float> poison, float %.0.i1329, i64 0, !dbg !295 + %6999 = insertelement <2 x float> %6998, float %.0.i1332, i64 1, !dbg !295 + %7000 = fptrunc <2 x float> %6999 to <2 x bfloat>, !dbg !295 + %7001 = insertelement <2 x float> poison, float %.0.i1335, i64 0, !dbg !295 + %7002 = insertelement <2 x float> %7001, float %.0.i1338, i64 1, !dbg !295 + %7003 = fptrunc <2 x float> %7002 to <2 x bfloat>, !dbg !295 + %7004 = insertelement <2 x float> poison, float %.0.i1341, i64 0, !dbg !295 + %7005 = insertelement <2 x float> %7004, float %.0.i1344, i64 1, !dbg !295 + %7006 = fptrunc <2 x float> %7005 to <2 x bfloat>, !dbg !295 + %7007 = insertelement <2 x float> poison, float %.0.i1347, i64 0, !dbg !295 + %7008 = insertelement <2 x float> %7007, float %.0.i1350, i64 1, !dbg !295 + %7009 = fptrunc <2 x float> %7008 to <2 x bfloat>, !dbg !295 + %7010 = insertelement <2 x float> poison, float %.0.i1353, i64 0, !dbg !295 + %7011 = insertelement <2 x float> %7010, float %.0.i1356, i64 1, !dbg !295 + %7012 = fptrunc <2 x float> %7011 to <2 x bfloat>, !dbg !295 + %7013 = insertelement <2 x float> poison, float %.0.i1359, i64 0, !dbg !295 + %7014 = insertelement <2 x float> %7013, float %.0.i1362, i64 1, !dbg !295 + %7015 = fptrunc <2 x float> %7014 to <2 x bfloat>, !dbg !295 + %7016 = insertelement <2 x float> poison, float %.0.i1365, i64 0, !dbg !295 + %7017 = insertelement <2 x float> %7016, float %.0.i1368, i64 1, !dbg !295 + %7018 = fptrunc <2 x float> %7017 to <2 x bfloat>, !dbg !295 + %7019 = insertelement <2 x float> poison, float %.0.i1371, i64 0, !dbg !295 + %7020 = insertelement <2 x float> %7019, float %.0.i1374, i64 1, !dbg !295 + %7021 = fptrunc <2 x float> %7020 to <2 x bfloat>, !dbg !295 + %7022 = insertelement <2 x float> poison, float %.0.i1377, i64 0, !dbg !295 + %7023 = insertelement <2 x float> %7022, float %.0.i1380, i64 1, !dbg !295 + %7024 = fptrunc <2 x float> %7023 to <2 x bfloat>, !dbg !295 + %7025 = insertelement <2 x float> poison, float %.0.i1383, i64 0, !dbg !295 + %7026 = insertelement <2 x float> %7025, float %.0.i1386, i64 1, !dbg !295 + %7027 = fptrunc <2 x float> %7026 to <2 x bfloat>, !dbg !295 + %7028 = insertelement <2 x float> poison, float %.0.i1389, i64 0, !dbg !295 + %7029 = insertelement <2 x float> %7028, float %.0.i1392, i64 1, !dbg !295 + %7030 = fptrunc <2 x float> %7029 to <2 x bfloat>, !dbg !295 + %7031 = insertelement <2 x float> poison, float %.0.i1395, i64 0, !dbg !295 + %7032 = insertelement <2 x float> %7031, float %.0.i1398, i64 1, !dbg !295 + %7033 = fptrunc <2 x float> %7032 to <2 x bfloat>, !dbg !295 + %7034 = insertelement <2 x float> poison, float %.0.i1401, i64 0, !dbg !295 + %7035 = insertelement <2 x float> %7034, float %.0.i1404, i64 1, !dbg !295 + %7036 = fptrunc <2 x float> %7035 to <2 x bfloat>, !dbg !295 + %7037 = insertelement <2 x float> poison, float %.0.i1407, i64 0, !dbg !295 + %7038 = insertelement <2 x float> %7037, float %.0.i1410, i64 1, !dbg !295 + %7039 = fptrunc <2 x float> %7038 to <2 x bfloat>, !dbg !295 + %7040 = insertelement <2 x float> poison, float %.0.i1413, i64 0, !dbg !295 + %7041 = insertelement <2 x float> %7040, float %.0.i1416, i64 1, !dbg !295 + %7042 = fptrunc <2 x float> %7041 to <2 x bfloat>, !dbg !295 + %7043 = bitcast <2 x bfloat> %6997 to i32, !dbg !296 + %7044 = bitcast <2 x bfloat> %7000 to i32, !dbg !296 + %7045 = bitcast <2 x bfloat> %7003 to i32, !dbg !296 + %7046 = bitcast <2 x bfloat> %7006 to i32, !dbg !296 + %7047 = bitcast <2 x bfloat> %7009 to i32, !dbg !296 + %7048 = bitcast <2 x bfloat> %7012 to i32, !dbg !296 + %7049 = bitcast <2 x bfloat> %7015 to i32, !dbg !296 + %7050 = bitcast <2 x bfloat> %7018 to i32, !dbg !296 + %7051 = bitcast <2 x bfloat> %7021 to i32, !dbg !296 + %7052 = bitcast <2 x bfloat> %7024 to i32, !dbg !296 + %7053 = bitcast <2 x bfloat> %7027 to i32, !dbg !296 + %7054 = bitcast <2 x bfloat> %7030 to i32, !dbg !296 + %7055 = bitcast <2 x bfloat> %7033 to i32, !dbg !296 + %7056 = bitcast <2 x bfloat> %7036 to i32, !dbg !296 + %7057 = bitcast <2 x bfloat> %7039 to i32, !dbg !296 + %7058 = bitcast <2 x bfloat> %7042 to i32, !dbg !296 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !296 + %7059 = ptrtoint ptr addrspace(3) %6994 to i32, !dbg !296 + %7060 = lshr exact i32 %7059, 4, !dbg !296 + %7061 = and i32 %7060, 16383, !dbg !296 + %7062 = zext nneg i32 %7061 to i64, !dbg !296 + %7063 = or disjoint i64 %7062, 4611686293338849280, !dbg !296 + %7064 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5847, float %5848, float %5849, float %5850, float %5851, float %5852, float %5853, float %5854, float %5855, float %5856, float %5857, float %5858, float %5859, float %5860, float %5861, float %5862, float %5863, float %5864, float %5865, float %5866, float %5867, float %5868, float %5869, float %5870, float %5871, float %5872, float %5873, float %5874, float %5875, float %5876, float %5877, float %5878, float %5879, float %5880, float %5881, float %5882, float %5883, float %5884, float %5885, float %5886, float %5887, float %5888, float %5889, float %5890, float %5891, float %5892, float %5893, float %5894, float %5895, float %5896, float %5897, float %5898, float %5899, float %5900, float %5901, float %5902, float %5903, float %5904, float %5905, float %5906, float %5907, float %5908, float %5909, float %5910, i32 %7043, i32 %7044, i32 %7045, i32 %7046, i64 %7063, i1 true) #3, !dbg !296 + %7065 = add i32 %7059, 2048, !dbg !296 + %7066 = lshr exact i32 %7065, 4, !dbg !296 + %7067 = and i32 %7066, 16383, !dbg !296 + %7068 = zext nneg i32 %7067 to i64, !dbg !296 + %7069 = or disjoint i64 %7068, 4611686293338849280, !dbg !296 + %7070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 0, !dbg !296 + %7071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 1, !dbg !296 + %7072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 2, !dbg !296 + %7073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 3, !dbg !296 + %7074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 4, !dbg !296 + %7075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 5, !dbg !296 + %7076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 6, !dbg !296 + %7077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 7, !dbg !296 + %7078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 8, !dbg !296 + %7079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 9, !dbg !296 + %7080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 10, !dbg !296 + %7081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 11, !dbg !296 + %7082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 12, !dbg !296 + %7083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 13, !dbg !296 + %7084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 14, !dbg !296 + %7085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 15, !dbg !296 + %7086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 16, !dbg !296 + %7087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 17, !dbg !296 + %7088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 18, !dbg !296 + %7089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 19, !dbg !296 + %7090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 20, !dbg !296 + %7091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 21, !dbg !296 + %7092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 22, !dbg !296 + %7093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 23, !dbg !296 + %7094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 24, !dbg !296 + %7095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 25, !dbg !296 + %7096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 26, !dbg !296 + %7097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 27, !dbg !296 + %7098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 28, !dbg !296 + %7099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 29, !dbg !296 + %7100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 30, !dbg !296 + %7101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 31, !dbg !296 + %7102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 32, !dbg !296 + %7103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 33, !dbg !296 + %7104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 34, !dbg !296 + %7105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 35, !dbg !296 + %7106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 36, !dbg !296 + %7107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 37, !dbg !296 + %7108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 38, !dbg !296 + %7109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 39, !dbg !296 + %7110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 40, !dbg !296 + %7111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 41, !dbg !296 + %7112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 42, !dbg !296 + %7113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 43, !dbg !296 + %7114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 44, !dbg !296 + %7115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 45, !dbg !296 + %7116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 46, !dbg !296 + %7117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 47, !dbg !296 + %7118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 48, !dbg !296 + %7119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 49, !dbg !296 + %7120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 50, !dbg !296 + %7121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 51, !dbg !296 + %7122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 52, !dbg !296 + %7123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 53, !dbg !296 + %7124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 54, !dbg !296 + %7125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 55, !dbg !296 + %7126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 56, !dbg !296 + %7127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 57, !dbg !296 + %7128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 58, !dbg !296 + %7129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 59, !dbg !296 + %7130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 60, !dbg !296 + %7131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 61, !dbg !296 + %7132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 62, !dbg !296 + %7133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7064, 63, !dbg !296 + %7134 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7070, float %7071, float %7072, float %7073, float %7074, float %7075, float %7076, float %7077, float %7078, float %7079, float %7080, float %7081, float %7082, float %7083, float %7084, float %7085, float %7086, float %7087, float %7088, float %7089, float %7090, float %7091, float %7092, float %7093, float %7094, float %7095, float %7096, float %7097, float %7098, float %7099, float %7100, float %7101, float %7102, float %7103, float %7104, float %7105, float %7106, float %7107, float %7108, float %7109, float %7110, float %7111, float %7112, float %7113, float %7114, float %7115, float %7116, float %7117, float %7118, float %7119, float %7120, float %7121, float %7122, float %7123, float %7124, float %7125, float %7126, float %7127, float %7128, float %7129, float %7130, float %7131, float %7132, float %7133, i32 %7047, i32 %7048, i32 %7049, i32 %7050, i64 %7069, i1 true) #3, !dbg !296 + %7135 = add i32 %7059, 4096, !dbg !296 + %7136 = lshr exact i32 %7135, 4, !dbg !296 + %7137 = and i32 %7136, 16383, !dbg !296 + %7138 = zext nneg i32 %7137 to i64, !dbg !296 + %7139 = or disjoint i64 %7138, 4611686293338849280, !dbg !296 + %7140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 0, !dbg !296 + %7141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 1, !dbg !296 + %7142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 2, !dbg !296 + %7143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 3, !dbg !296 + %7144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 4, !dbg !296 + %7145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 5, !dbg !296 + %7146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 6, !dbg !296 + %7147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 7, !dbg !296 + %7148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 8, !dbg !296 + %7149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 9, !dbg !296 + %7150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 10, !dbg !296 + %7151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 11, !dbg !296 + %7152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 12, !dbg !296 + %7153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 13, !dbg !296 + %7154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 14, !dbg !296 + %7155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 15, !dbg !296 + %7156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 16, !dbg !296 + %7157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 17, !dbg !296 + %7158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 18, !dbg !296 + %7159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 19, !dbg !296 + %7160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 20, !dbg !296 + %7161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 21, !dbg !296 + %7162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 22, !dbg !296 + %7163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 23, !dbg !296 + %7164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 24, !dbg !296 + %7165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 25, !dbg !296 + %7166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 26, !dbg !296 + %7167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 27, !dbg !296 + %7168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 28, !dbg !296 + %7169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 29, !dbg !296 + %7170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 30, !dbg !296 + %7171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 31, !dbg !296 + %7172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 32, !dbg !296 + %7173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 33, !dbg !296 + %7174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 34, !dbg !296 + %7175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 35, !dbg !296 + %7176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 36, !dbg !296 + %7177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 37, !dbg !296 + %7178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 38, !dbg !296 + %7179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 39, !dbg !296 + %7180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 40, !dbg !296 + %7181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 41, !dbg !296 + %7182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 42, !dbg !296 + %7183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 43, !dbg !296 + %7184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 44, !dbg !296 + %7185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 45, !dbg !296 + %7186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 46, !dbg !296 + %7187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 47, !dbg !296 + %7188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 48, !dbg !296 + %7189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 49, !dbg !296 + %7190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 50, !dbg !296 + %7191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 51, !dbg !296 + %7192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 52, !dbg !296 + %7193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 53, !dbg !296 + %7194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 54, !dbg !296 + %7195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 55, !dbg !296 + %7196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 56, !dbg !296 + %7197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 57, !dbg !296 + %7198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 58, !dbg !296 + %7199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 59, !dbg !296 + %7200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 60, !dbg !296 + %7201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 61, !dbg !296 + %7202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 62, !dbg !296 + %7203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7134, 63, !dbg !296 + %7204 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7140, float %7141, float %7142, float %7143, float %7144, float %7145, float %7146, float %7147, float %7148, float %7149, float %7150, float %7151, float %7152, float %7153, float %7154, float %7155, float %7156, float %7157, float %7158, float %7159, float %7160, float %7161, float %7162, float %7163, float %7164, float %7165, float %7166, float %7167, float %7168, float %7169, float %7170, float %7171, float %7172, float %7173, float %7174, float %7175, float %7176, float %7177, float %7178, float %7179, float %7180, float %7181, float %7182, float %7183, float %7184, float %7185, float %7186, float %7187, float %7188, float %7189, float %7190, float %7191, float %7192, float %7193, float %7194, float %7195, float %7196, float %7197, float %7198, float %7199, float %7200, float %7201, float %7202, float %7203, i32 %7051, i32 %7052, i32 %7053, i32 %7054, i64 %7139, i1 true) #3, !dbg !296 + %7205 = add i32 %7059, 6144, !dbg !296 + %7206 = lshr exact i32 %7205, 4, !dbg !296 + %7207 = and i32 %7206, 16383, !dbg !296 + %7208 = zext nneg i32 %7207 to i64, !dbg !296 + %7209 = or disjoint i64 %7208, 4611686293338849280, !dbg !296 + %7210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 0, !dbg !296 + %7211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 1, !dbg !296 + %7212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 2, !dbg !296 + %7213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 3, !dbg !296 + %7214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 4, !dbg !296 + %7215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 5, !dbg !296 + %7216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 6, !dbg !296 + %7217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 7, !dbg !296 + %7218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 8, !dbg !296 + %7219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 9, !dbg !296 + %7220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 10, !dbg !296 + %7221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 11, !dbg !296 + %7222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 12, !dbg !296 + %7223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 13, !dbg !296 + %7224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 14, !dbg !296 + %7225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 15, !dbg !296 + %7226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 16, !dbg !296 + %7227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 17, !dbg !296 + %7228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 18, !dbg !296 + %7229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 19, !dbg !296 + %7230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 20, !dbg !296 + %7231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 21, !dbg !296 + %7232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 22, !dbg !296 + %7233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 23, !dbg !296 + %7234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 24, !dbg !296 + %7235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 25, !dbg !296 + %7236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 26, !dbg !296 + %7237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 27, !dbg !296 + %7238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 28, !dbg !296 + %7239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 29, !dbg !296 + %7240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 30, !dbg !296 + %7241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 31, !dbg !296 + %7242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 32, !dbg !296 + %7243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 33, !dbg !296 + %7244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 34, !dbg !296 + %7245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 35, !dbg !296 + %7246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 36, !dbg !296 + %7247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 37, !dbg !296 + %7248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 38, !dbg !296 + %7249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 39, !dbg !296 + %7250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 40, !dbg !296 + %7251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 41, !dbg !296 + %7252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 42, !dbg !296 + %7253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 43, !dbg !296 + %7254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 44, !dbg !296 + %7255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 45, !dbg !296 + %7256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 46, !dbg !296 + %7257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 47, !dbg !296 + %7258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 48, !dbg !296 + %7259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 49, !dbg !296 + %7260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 50, !dbg !296 + %7261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 51, !dbg !296 + %7262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 52, !dbg !296 + %7263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 53, !dbg !296 + %7264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 54, !dbg !296 + %7265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 55, !dbg !296 + %7266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 56, !dbg !296 + %7267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 57, !dbg !296 + %7268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 58, !dbg !296 + %7269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 59, !dbg !296 + %7270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 60, !dbg !296 + %7271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 61, !dbg !296 + %7272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 62, !dbg !296 + %7273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7204, 63, !dbg !296 + %7274 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7210, float %7211, float %7212, float %7213, float %7214, float %7215, float %7216, float %7217, float %7218, float %7219, float %7220, float %7221, float %7222, float %7223, float %7224, float %7225, float %7226, float %7227, float %7228, float %7229, float %7230, float %7231, float %7232, float %7233, float %7234, float %7235, float %7236, float %7237, float %7238, float %7239, float %7240, float %7241, float %7242, float %7243, float %7244, float %7245, float %7246, float %7247, float %7248, float %7249, float %7250, float %7251, float %7252, float %7253, float %7254, float %7255, float %7256, float %7257, float %7258, float %7259, float %7260, float %7261, float %7262, float %7263, float %7264, float %7265, float %7266, float %7267, float %7268, float %7269, float %7270, float %7271, float %7272, float %7273, i32 %7055, i32 %7056, i32 %7057, i32 %7058, i64 %7209, i1 true) #3, !dbg !296 + %7275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 0, !dbg !296 + %7276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 1, !dbg !296 + %7277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 2, !dbg !296 + %7278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 3, !dbg !296 + %7279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 4, !dbg !296 + %7280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 5, !dbg !296 + %7281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 6, !dbg !296 + %7282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 7, !dbg !296 + %7283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 8, !dbg !296 + %7284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 9, !dbg !296 + %7285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 10, !dbg !296 + %7286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 11, !dbg !296 + %7287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 12, !dbg !296 + %7288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 13, !dbg !296 + %7289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 14, !dbg !296 + %7290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 15, !dbg !296 + %7291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 16, !dbg !296 + %7292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 17, !dbg !296 + %7293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 18, !dbg !296 + %7294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 19, !dbg !296 + %7295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 20, !dbg !296 + %7296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 21, !dbg !296 + %7297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 22, !dbg !296 + %7298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 23, !dbg !296 + %7299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 24, !dbg !296 + %7300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 25, !dbg !296 + %7301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 26, !dbg !296 + %7302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 27, !dbg !296 + %7303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 28, !dbg !296 + %7304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 29, !dbg !296 + %7305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 30, !dbg !296 + %7306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 31, !dbg !296 + %7307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 32, !dbg !296 + %7308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 33, !dbg !296 + %7309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 34, !dbg !296 + %7310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 35, !dbg !296 + %7311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 36, !dbg !296 + %7312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 37, !dbg !296 + %7313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 38, !dbg !296 + %7314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 39, !dbg !296 + %7315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 40, !dbg !296 + %7316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 41, !dbg !296 + %7317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 42, !dbg !296 + %7318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 43, !dbg !296 + %7319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 44, !dbg !296 + %7320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 45, !dbg !296 + %7321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 46, !dbg !296 + %7322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 47, !dbg !296 + %7323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 48, !dbg !296 + %7324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 49, !dbg !296 + %7325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 50, !dbg !296 + %7326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 51, !dbg !296 + %7327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 52, !dbg !296 + %7328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 53, !dbg !296 + %7329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 54, !dbg !296 + %7330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 55, !dbg !296 + %7331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 56, !dbg !296 + %7332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 57, !dbg !296 + %7333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 58, !dbg !296 + %7334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 59, !dbg !296 + %7335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 60, !dbg !296 + %7336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 61, !dbg !296 + %7337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 62, !dbg !296 + %7338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7274, 63, !dbg !296 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !296 + %7339 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %6003, !dbg !273 + %7340 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5189, !dbg !273 + %7341 = load float, ptr addrspace(3) %7340, align 8, !dbg !273 + %7342 = getelementptr inbounds nuw i8, ptr addrspace(3) %7340, i32 4, !dbg !273 + %7343 = load float, ptr addrspace(3) %7342, align 4, !dbg !273 + %7344 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5195, !dbg !273 + %7345 = load float, ptr addrspace(3) %7344, align 8, !dbg !273 + %7346 = getelementptr inbounds nuw i8, ptr addrspace(3) %7344, i32 4, !dbg !273 + %7347 = load float, ptr addrspace(3) %7346, align 4, !dbg !273 + %7348 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5201, !dbg !273 + %7349 = load float, ptr addrspace(3) %7348, align 8, !dbg !273 + %7350 = getelementptr inbounds nuw i8, ptr addrspace(3) %7348, i32 4, !dbg !273 + %7351 = load float, ptr addrspace(3) %7350, align 4, !dbg !273 + %7352 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5207, !dbg !273 + %7353 = load float, ptr addrspace(3) %7352, align 8, !dbg !273 + %7354 = getelementptr inbounds nuw i8, ptr addrspace(3) %7352, i32 4, !dbg !273 + %7355 = load float, ptr addrspace(3) %7354, align 4, !dbg !273 + %7356 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5213, !dbg !273 + %7357 = load float, ptr addrspace(3) %7356, align 8, !dbg !273 + %7358 = getelementptr inbounds nuw i8, ptr addrspace(3) %7356, i32 4, !dbg !273 + %7359 = load float, ptr addrspace(3) %7358, align 4, !dbg !273 + %7360 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5219, !dbg !273 + %7361 = load float, ptr addrspace(3) %7360, align 8, !dbg !273 + %7362 = getelementptr inbounds nuw i8, ptr addrspace(3) %7360, i32 4, !dbg !273 + %7363 = load float, ptr addrspace(3) %7362, align 4, !dbg !273 + %7364 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5225, !dbg !273 + %7365 = load float, ptr addrspace(3) %7364, align 8, !dbg !273 + %7366 = getelementptr inbounds nuw i8, ptr addrspace(3) %7364, i32 4, !dbg !273 + %7367 = load float, ptr addrspace(3) %7366, align 4, !dbg !273 + %7368 = getelementptr inbounds nuw i8, ptr addrspace(3) %7339, i32 %5231, !dbg !273 + %7369 = load float, ptr addrspace(3) %7368, align 8, !dbg !273 + %7370 = getelementptr inbounds nuw i8, ptr addrspace(3) %7368, i32 4, !dbg !273 + %7371 = load float, ptr addrspace(3) %7370, align 4, !dbg !273 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !297 + %7372 = add i32 %6071, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7373 = lshr exact i32 %7372, 4, !dbg !297 + %7374 = and i32 %7373, 16383, !dbg !297 + %7375 = zext nneg i32 %7374 to i64, !dbg !297 + %7376 = or disjoint i64 %7375, 4611686293372403712, !dbg !297 + %7377 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %7376, i64 %7063) #3, !dbg !297 + %7378 = add i32 %6083, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7379 = lshr exact i32 %7378, 4, !dbg !297 + %7380 = and i32 %7379, 16383, !dbg !297 + %7381 = zext nneg i32 %7380 to i64, !dbg !297 + %7382 = or disjoint i64 %7381, 4611686293372403712, !dbg !297 + %7383 = add i32 %7059, 32, !dbg !297 + %7384 = lshr exact i32 %7383, 4, !dbg !297 + %7385 = and i32 %7384, 16383, !dbg !297 + %7386 = zext nneg i32 %7385 to i64, !dbg !297 + %7387 = or disjoint i64 %7386, 4611686293338849280, !dbg !297 + %7388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 0, !dbg !297 + %7389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 1, !dbg !297 + %7390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 2, !dbg !297 + %7391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 3, !dbg !297 + %7392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 4, !dbg !297 + %7393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 5, !dbg !297 + %7394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 6, !dbg !297 + %7395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 7, !dbg !297 + %7396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 8, !dbg !297 + %7397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 9, !dbg !297 + %7398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 10, !dbg !297 + %7399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 11, !dbg !297 + %7400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 12, !dbg !297 + %7401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 13, !dbg !297 + %7402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 14, !dbg !297 + %7403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 15, !dbg !297 + %7404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 16, !dbg !297 + %7405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 17, !dbg !297 + %7406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 18, !dbg !297 + %7407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 19, !dbg !297 + %7408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 20, !dbg !297 + %7409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 21, !dbg !297 + %7410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 22, !dbg !297 + %7411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 23, !dbg !297 + %7412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 24, !dbg !297 + %7413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 25, !dbg !297 + %7414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 26, !dbg !297 + %7415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 27, !dbg !297 + %7416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 28, !dbg !297 + %7417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 29, !dbg !297 + %7418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 30, !dbg !297 + %7419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7377, 31, !dbg !297 + %7420 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7388, float %7389, float %7390, float %7391, float %7392, float %7393, float %7394, float %7395, float %7396, float %7397, float %7398, float %7399, float %7400, float %7401, float %7402, float %7403, float %7404, float %7405, float %7406, float %7407, float %7408, float %7409, float %7410, float %7411, float %7412, float %7413, float %7414, float %7415, float %7416, float %7417, float %7418, float %7419, i64 %7382, i64 %7387, i1 true) #3, !dbg !297 + %7421 = add i32 %6127, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7422 = lshr exact i32 %7421, 4, !dbg !297 + %7423 = and i32 %7422, 16383, !dbg !297 + %7424 = zext nneg i32 %7423 to i64, !dbg !297 + %7425 = or disjoint i64 %7424, 4611686293372403712, !dbg !297 + %7426 = add i32 %7059, 64, !dbg !297 + %7427 = lshr exact i32 %7426, 4, !dbg !297 + %7428 = and i32 %7427, 16383, !dbg !297 + %7429 = zext nneg i32 %7428 to i64, !dbg !297 + %7430 = or disjoint i64 %7429, 4611686293338849280, !dbg !297 + %7431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 0, !dbg !297 + %7432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 1, !dbg !297 + %7433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 2, !dbg !297 + %7434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 3, !dbg !297 + %7435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 4, !dbg !297 + %7436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 5, !dbg !297 + %7437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 6, !dbg !297 + %7438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 7, !dbg !297 + %7439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 8, !dbg !297 + %7440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 9, !dbg !297 + %7441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 10, !dbg !297 + %7442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 11, !dbg !297 + %7443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 12, !dbg !297 + %7444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 13, !dbg !297 + %7445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 14, !dbg !297 + %7446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 15, !dbg !297 + %7447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 16, !dbg !297 + %7448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 17, !dbg !297 + %7449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 18, !dbg !297 + %7450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 19, !dbg !297 + %7451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 20, !dbg !297 + %7452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 21, !dbg !297 + %7453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 22, !dbg !297 + %7454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 23, !dbg !297 + %7455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 24, !dbg !297 + %7456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 25, !dbg !297 + %7457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 26, !dbg !297 + %7458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 27, !dbg !297 + %7459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 28, !dbg !297 + %7460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 29, !dbg !297 + %7461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 30, !dbg !297 + %7462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7420, 31, !dbg !297 + %7463 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7431, float %7432, float %7433, float %7434, float %7435, float %7436, float %7437, float %7438, float %7439, float %7440, float %7441, float %7442, float %7443, float %7444, float %7445, float %7446, float %7447, float %7448, float %7449, float %7450, float %7451, float %7452, float %7453, float %7454, float %7455, float %7456, float %7457, float %7458, float %7459, float %7460, float %7461, float %7462, i64 %7425, i64 %7430, i1 true) #3, !dbg !297 + %7464 = add i32 %6171, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7465 = lshr exact i32 %7464, 4, !dbg !297 + %7466 = and i32 %7465, 16383, !dbg !297 + %7467 = zext nneg i32 %7466 to i64, !dbg !297 + %7468 = or disjoint i64 %7467, 4611686293372403712, !dbg !297 + %7469 = add i32 %7059, 96, !dbg !297 + %7470 = lshr exact i32 %7469, 4, !dbg !297 + %7471 = and i32 %7470, 16383, !dbg !297 + %7472 = zext nneg i32 %7471 to i64, !dbg !297 + %7473 = or disjoint i64 %7472, 4611686293338849280, !dbg !297 + %7474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 0, !dbg !297 + %7475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 1, !dbg !297 + %7476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 2, !dbg !297 + %7477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 3, !dbg !297 + %7478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 4, !dbg !297 + %7479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 5, !dbg !297 + %7480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 6, !dbg !297 + %7481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 7, !dbg !297 + %7482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 8, !dbg !297 + %7483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 9, !dbg !297 + %7484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 10, !dbg !297 + %7485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 11, !dbg !297 + %7486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 12, !dbg !297 + %7487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 13, !dbg !297 + %7488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 14, !dbg !297 + %7489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 15, !dbg !297 + %7490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 16, !dbg !297 + %7491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 17, !dbg !297 + %7492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 18, !dbg !297 + %7493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 19, !dbg !297 + %7494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 20, !dbg !297 + %7495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 21, !dbg !297 + %7496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 22, !dbg !297 + %7497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 23, !dbg !297 + %7498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 24, !dbg !297 + %7499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 25, !dbg !297 + %7500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 26, !dbg !297 + %7501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 27, !dbg !297 + %7502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 28, !dbg !297 + %7503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 29, !dbg !297 + %7504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 30, !dbg !297 + %7505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7463, 31, !dbg !297 + %7506 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7474, float %7475, float %7476, float %7477, float %7478, float %7479, float %7480, float %7481, float %7482, float %7483, float %7484, float %7485, float %7486, float %7487, float %7488, float %7489, float %7490, float %7491, float %7492, float %7493, float %7494, float %7495, float %7496, float %7497, float %7498, float %7499, float %7500, float %7501, float %7502, float %7503, float %7504, float %7505, i64 %7468, i64 %7473, i1 true) #3, !dbg !297 + %7507 = add i32 %6215, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7508 = lshr exact i32 %7507, 4, !dbg !297 + %7509 = and i32 %7508, 16383, !dbg !297 + %7510 = zext nneg i32 %7509 to i64, !dbg !297 + %7511 = or disjoint i64 %7510, 4611686293372403712, !dbg !297 + %7512 = add i32 %7059, 8192, !dbg !297 + %7513 = lshr exact i32 %7512, 4, !dbg !297 + %7514 = and i32 %7513, 16383, !dbg !297 + %7515 = zext nneg i32 %7514 to i64, !dbg !297 + %7516 = or disjoint i64 %7515, 4611686293338849280, !dbg !297 + %7517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 0, !dbg !297 + %7518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 1, !dbg !297 + %7519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 2, !dbg !297 + %7520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 3, !dbg !297 + %7521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 4, !dbg !297 + %7522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 5, !dbg !297 + %7523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 6, !dbg !297 + %7524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 7, !dbg !297 + %7525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 8, !dbg !297 + %7526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 9, !dbg !297 + %7527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 10, !dbg !297 + %7528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 11, !dbg !297 + %7529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 12, !dbg !297 + %7530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 13, !dbg !297 + %7531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 14, !dbg !297 + %7532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 15, !dbg !297 + %7533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 16, !dbg !297 + %7534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 17, !dbg !297 + %7535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 18, !dbg !297 + %7536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 19, !dbg !297 + %7537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 20, !dbg !297 + %7538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 21, !dbg !297 + %7539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 22, !dbg !297 + %7540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 23, !dbg !297 + %7541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 24, !dbg !297 + %7542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 25, !dbg !297 + %7543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 26, !dbg !297 + %7544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 27, !dbg !297 + %7545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 28, !dbg !297 + %7546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 29, !dbg !297 + %7547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 30, !dbg !297 + %7548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7506, 31, !dbg !297 + %7549 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7517, float %7518, float %7519, float %7520, float %7521, float %7522, float %7523, float %7524, float %7525, float %7526, float %7527, float %7528, float %7529, float %7530, float %7531, float %7532, float %7533, float %7534, float %7535, float %7536, float %7537, float %7538, float %7539, float %7540, float %7541, float %7542, float %7543, float %7544, float %7545, float %7546, float %7547, float %7548, i64 %7511, i64 %7516, i1 true) #3, !dbg !297 + %7550 = add i32 %6259, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7551 = lshr exact i32 %7550, 4, !dbg !297 + %7552 = and i32 %7551, 16383, !dbg !297 + %7553 = zext nneg i32 %7552 to i64, !dbg !297 + %7554 = or disjoint i64 %7553, 4611686293372403712, !dbg !297 + %7555 = add i32 %7059, 8224, !dbg !297 + %7556 = lshr exact i32 %7555, 4, !dbg !297 + %7557 = and i32 %7556, 16383, !dbg !297 + %7558 = zext nneg i32 %7557 to i64, !dbg !297 + %7559 = or disjoint i64 %7558, 4611686293338849280, !dbg !297 + %7560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 0, !dbg !297 + %7561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 1, !dbg !297 + %7562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 2, !dbg !297 + %7563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 3, !dbg !297 + %7564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 4, !dbg !297 + %7565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 5, !dbg !297 + %7566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 6, !dbg !297 + %7567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 7, !dbg !297 + %7568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 8, !dbg !297 + %7569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 9, !dbg !297 + %7570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 10, !dbg !297 + %7571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 11, !dbg !297 + %7572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 12, !dbg !297 + %7573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 13, !dbg !297 + %7574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 14, !dbg !297 + %7575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 15, !dbg !297 + %7576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 16, !dbg !297 + %7577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 17, !dbg !297 + %7578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 18, !dbg !297 + %7579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 19, !dbg !297 + %7580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 20, !dbg !297 + %7581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 21, !dbg !297 + %7582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 22, !dbg !297 + %7583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 23, !dbg !297 + %7584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 24, !dbg !297 + %7585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 25, !dbg !297 + %7586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 26, !dbg !297 + %7587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 27, !dbg !297 + %7588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 28, !dbg !297 + %7589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 29, !dbg !297 + %7590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 30, !dbg !297 + %7591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7549, 31, !dbg !297 + %7592 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7560, float %7561, float %7562, float %7563, float %7564, float %7565, float %7566, float %7567, float %7568, float %7569, float %7570, float %7571, float %7572, float %7573, float %7574, float %7575, float %7576, float %7577, float %7578, float %7579, float %7580, float %7581, float %7582, float %7583, float %7584, float %7585, float %7586, float %7587, float %7588, float %7589, float %7590, float %7591, i64 %7554, i64 %7559, i1 true) #3, !dbg !297 + %7593 = add i32 %6303, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7594 = lshr exact i32 %7593, 4, !dbg !297 + %7595 = and i32 %7594, 16383, !dbg !297 + %7596 = zext nneg i32 %7595 to i64, !dbg !297 + %7597 = or disjoint i64 %7596, 4611686293372403712, !dbg !297 + %7598 = add i32 %7059, 8256, !dbg !297 + %7599 = lshr exact i32 %7598, 4, !dbg !297 + %7600 = and i32 %7599, 16383, !dbg !297 + %7601 = zext nneg i32 %7600 to i64, !dbg !297 + %7602 = or disjoint i64 %7601, 4611686293338849280, !dbg !297 + %7603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 0, !dbg !297 + %7604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 1, !dbg !297 + %7605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 2, !dbg !297 + %7606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 3, !dbg !297 + %7607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 4, !dbg !297 + %7608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 5, !dbg !297 + %7609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 6, !dbg !297 + %7610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 7, !dbg !297 + %7611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 8, !dbg !297 + %7612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 9, !dbg !297 + %7613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 10, !dbg !297 + %7614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 11, !dbg !297 + %7615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 12, !dbg !297 + %7616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 13, !dbg !297 + %7617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 14, !dbg !297 + %7618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 15, !dbg !297 + %7619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 16, !dbg !297 + %7620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 17, !dbg !297 + %7621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 18, !dbg !297 + %7622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 19, !dbg !297 + %7623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 20, !dbg !297 + %7624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 21, !dbg !297 + %7625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 22, !dbg !297 + %7626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 23, !dbg !297 + %7627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 24, !dbg !297 + %7628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 25, !dbg !297 + %7629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 26, !dbg !297 + %7630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 27, !dbg !297 + %7631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 28, !dbg !297 + %7632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 29, !dbg !297 + %7633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 30, !dbg !297 + %7634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7592, 31, !dbg !297 + %7635 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7603, float %7604, float %7605, float %7606, float %7607, float %7608, float %7609, float %7610, float %7611, float %7612, float %7613, float %7614, float %7615, float %7616, float %7617, float %7618, float %7619, float %7620, float %7621, float %7622, float %7623, float %7624, float %7625, float %7626, float %7627, float %7628, float %7629, float %7630, float %7631, float %7632, float %7633, float %7634, i64 %7597, i64 %7602, i1 true) #3, !dbg !297 + %7636 = add i32 %6347, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !297 + %7637 = lshr exact i32 %7636, 4, !dbg !297 + %7638 = and i32 %7637, 16383, !dbg !297 + %7639 = zext nneg i32 %7638 to i64, !dbg !297 + %7640 = or disjoint i64 %7639, 4611686293372403712, !dbg !297 + %7641 = add i32 %7059, 8288, !dbg !297 + %7642 = lshr exact i32 %7641, 4, !dbg !297 + %7643 = and i32 %7642, 16383, !dbg !297 + %7644 = zext nneg i32 %7643 to i64, !dbg !297 + %7645 = or disjoint i64 %7644, 4611686293338849280, !dbg !297 + %7646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 0, !dbg !297 + %7647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 1, !dbg !297 + %7648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 2, !dbg !297 + %7649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 3, !dbg !297 + %7650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 4, !dbg !297 + %7651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 5, !dbg !297 + %7652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 6, !dbg !297 + %7653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 7, !dbg !297 + %7654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 8, !dbg !297 + %7655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 9, !dbg !297 + %7656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 10, !dbg !297 + %7657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 11, !dbg !297 + %7658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 12, !dbg !297 + %7659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 13, !dbg !297 + %7660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 14, !dbg !297 + %7661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 15, !dbg !297 + %7662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 16, !dbg !297 + %7663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 17, !dbg !297 + %7664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 18, !dbg !297 + %7665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 19, !dbg !297 + %7666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 20, !dbg !297 + %7667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 21, !dbg !297 + %7668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 22, !dbg !297 + %7669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 23, !dbg !297 + %7670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 24, !dbg !297 + %7671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 25, !dbg !297 + %7672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 26, !dbg !297 + %7673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 27, !dbg !297 + %7674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 28, !dbg !297 + %7675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 29, !dbg !297 + %7676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 30, !dbg !297 + %7677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7635, 31, !dbg !297 + %7678 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7646, float %7647, float %7648, float %7649, float %7650, float %7651, float %7652, float %7653, float %7654, float %7655, float %7656, float %7657, float %7658, float %7659, float %7660, float %7661, float %7662, float %7663, float %7664, float %7665, float %7666, float %7667, float %7668, float %7669, float %7670, float %7671, float %7672, float %7673, float %7674, float %7675, float %7676, float %7677, i64 %7640, i64 %7645, i1 true) #3, !dbg !297 + %7679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 0, !dbg !297 + %7680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 1, !dbg !297 + %7681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 2, !dbg !297 + %7682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 3, !dbg !297 + %7683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 4, !dbg !297 + %7684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 5, !dbg !297 + %7685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 6, !dbg !297 + %7686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 7, !dbg !297 + %7687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 8, !dbg !297 + %7688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 9, !dbg !297 + %7689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 10, !dbg !297 + %7690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 11, !dbg !297 + %7691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 12, !dbg !297 + %7692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 13, !dbg !297 + %7693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 14, !dbg !297 + %7694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 15, !dbg !297 + %7695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 16, !dbg !297 + %7696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 17, !dbg !297 + %7697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 18, !dbg !297 + %7698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 19, !dbg !297 + %7699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 20, !dbg !297 + %7700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 21, !dbg !297 + %7701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 22, !dbg !297 + %7702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 23, !dbg !297 + %7703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 24, !dbg !297 + %7704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 25, !dbg !297 + %7705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 26, !dbg !297 + %7706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 27, !dbg !297 + %7707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 28, !dbg !297 + %7708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 29, !dbg !297 + %7709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 30, !dbg !297 + %7710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7678, 31, !dbg !297 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !297 + %7711 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %7679, float %7680, float %7681, float %7682, float %7683, float %7684, float %7685, float %7686, float %7687, float %7688, float %7689, float %7690, float %7691, float %7692, float %7693, float %7694, float %7695, float %7696, float %7697, float %7698, float %7699, float %7700, float %7701, float %7702, float %7703, float %7704, float %7705, float %7706, float %7707, float %7708, float %7709, float %7710, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %6994, i32 0, i32 0) #3, !dbg !297 + %7712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 0, !dbg !297 + %7713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 1, !dbg !297 + %7714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 2, !dbg !297 + %7715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 3, !dbg !297 + %7716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 4, !dbg !297 + %7717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 5, !dbg !297 + %7718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 6, !dbg !297 + %7719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 7, !dbg !297 + %7720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 8, !dbg !297 + %7721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 9, !dbg !297 + %7722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 10, !dbg !297 + %7723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 11, !dbg !297 + %7724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 12, !dbg !297 + %7725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 13, !dbg !297 + %7726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 14, !dbg !297 + %7727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 15, !dbg !297 + %7728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 16, !dbg !297 + %7729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 17, !dbg !297 + %7730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 18, !dbg !297 + %7731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 19, !dbg !297 + %7732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 20, !dbg !297 + %7733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 21, !dbg !297 + %7734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 22, !dbg !297 + %7735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 23, !dbg !297 + %7736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 24, !dbg !297 + %7737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 25, !dbg !297 + %7738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 26, !dbg !297 + %7739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 27, !dbg !297 + %7740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 28, !dbg !297 + %7741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 29, !dbg !297 + %7742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 30, !dbg !297 + %7743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7711, 31, !dbg !297 + %7744 = fsub float %7712, %7341, !dbg !298 + %7745 = fsub float %7713, %7343, !dbg !298 + %7746 = fsub float %7714, %7341, !dbg !298 + %7747 = fsub float %7715, %7343, !dbg !298 + %7748 = fsub float %7716, %7345, !dbg !298 + %7749 = fsub float %7717, %7347, !dbg !298 + %7750 = fsub float %7718, %7345, !dbg !298 + %7751 = fsub float %7719, %7347, !dbg !298 + %7752 = fsub float %7720, %7349, !dbg !298 + %7753 = fsub float %7721, %7351, !dbg !298 + %7754 = fsub float %7722, %7349, !dbg !298 + %7755 = fsub float %7723, %7351, !dbg !298 + %7756 = fsub float %7724, %7353, !dbg !298 + %7757 = fsub float %7725, %7355, !dbg !298 + %7758 = fsub float %7726, %7353, !dbg !298 + %7759 = fsub float %7727, %7355, !dbg !298 + %7760 = fsub float %7728, %7357, !dbg !298 + %7761 = fsub float %7729, %7359, !dbg !298 + %7762 = fsub float %7730, %7357, !dbg !298 + %7763 = fsub float %7731, %7359, !dbg !298 + %7764 = fsub float %7732, %7361, !dbg !298 + %7765 = fsub float %7733, %7363, !dbg !298 + %7766 = fsub float %7734, %7361, !dbg !298 + %7767 = fsub float %7735, %7363, !dbg !298 + %7768 = fsub float %7736, %7365, !dbg !298 + %7769 = fsub float %7737, %7367, !dbg !298 + %7770 = fsub float %7738, %7365, !dbg !298 + %7771 = fsub float %7739, %7367, !dbg !298 + %7772 = fsub float %7740, %7369, !dbg !298 + %7773 = fsub float %7741, %7371, !dbg !298 + %7774 = fsub float %7742, %7369, !dbg !298 + %7775 = fsub float %7743, %7371, !dbg !298 + %7776 = fmul float %.0.i1323, %7744, !dbg !299 + %7777 = fmul float %.0.i1326, %7745, !dbg !299 + %7778 = fmul float %.0.i1329, %7746, !dbg !299 + %7779 = fmul float %.0.i1332, %7747, !dbg !299 + %7780 = fmul float %.0.i1335, %7748, !dbg !299 + %7781 = fmul float %.0.i1338, %7749, !dbg !299 + %7782 = fmul float %.0.i1341, %7750, !dbg !299 + %7783 = fmul float %.0.i1344, %7751, !dbg !299 + %7784 = fmul float %.0.i1347, %7752, !dbg !299 + %7785 = fmul float %.0.i1350, %7753, !dbg !299 + %7786 = fmul float %.0.i1353, %7754, !dbg !299 + %7787 = fmul float %.0.i1356, %7755, !dbg !299 + %7788 = fmul float %.0.i1359, %7756, !dbg !299 + %7789 = fmul float %.0.i1362, %7757, !dbg !299 + %7790 = fmul float %.0.i1365, %7758, !dbg !299 + %7791 = fmul float %.0.i1368, %7759, !dbg !299 + %7792 = fmul float %.0.i1371, %7760, !dbg !299 + %7793 = fmul float %.0.i1374, %7761, !dbg !299 + %7794 = fmul float %.0.i1377, %7762, !dbg !299 + %7795 = fmul float %.0.i1380, %7763, !dbg !299 + %7796 = fmul float %.0.i1383, %7764, !dbg !299 + %7797 = fmul float %.0.i1386, %7765, !dbg !299 + %7798 = fmul float %.0.i1389, %7766, !dbg !299 + %7799 = fmul float %.0.i1392, %7767, !dbg !299 + %7800 = fmul float %.0.i1395, %7768, !dbg !299 + %7801 = fmul float %.0.i1398, %7769, !dbg !299 + %7802 = fmul float %.0.i1401, %7770, !dbg !299 + %7803 = fmul float %.0.i1404, %7771, !dbg !299 + %7804 = fmul float %.0.i1407, %7772, !dbg !299 + %7805 = fmul float %.0.i1410, %7773, !dbg !299 + %7806 = fmul float %.0.i1413, %7774, !dbg !299 + %7807 = fmul float %.0.i1416, %7775, !dbg !299 + %7808 = fptrunc float %7776 to bfloat, !dbg !300 + %7809 = select i1 %6706, bfloat %7808, bfloat 0xR0000, !dbg !301 + %7810 = fptrunc float %7777 to bfloat, !dbg !300 + %7811 = select i1 %6707, bfloat %7810, bfloat 0xR0000, !dbg !301 + %7812 = fptrunc float %7778 to bfloat, !dbg !300 + %7813 = select i1 %6708, bfloat %7812, bfloat 0xR0000, !dbg !301 + %7814 = fptrunc float %7779 to bfloat, !dbg !300 + %7815 = select i1 %6709, bfloat %7814, bfloat 0xR0000, !dbg !301 + %7816 = fptrunc float %7780 to bfloat, !dbg !300 + %7817 = select i1 %6710, bfloat %7816, bfloat 0xR0000, !dbg !301 + %7818 = fptrunc float %7781 to bfloat, !dbg !300 + %7819 = select i1 %6711, bfloat %7818, bfloat 0xR0000, !dbg !301 + %7820 = fptrunc float %7782 to bfloat, !dbg !300 + %7821 = select i1 %6712, bfloat %7820, bfloat 0xR0000, !dbg !301 + %7822 = fptrunc float %7783 to bfloat, !dbg !300 + %7823 = select i1 %6713, bfloat %7822, bfloat 0xR0000, !dbg !301 + %7824 = fptrunc float %7784 to bfloat, !dbg !300 + %7825 = select i1 %6714, bfloat %7824, bfloat 0xR0000, !dbg !301 + %7826 = fptrunc float %7785 to bfloat, !dbg !300 + %7827 = select i1 %6715, bfloat %7826, bfloat 0xR0000, !dbg !301 + %7828 = fptrunc float %7786 to bfloat, !dbg !300 + %7829 = select i1 %6716, bfloat %7828, bfloat 0xR0000, !dbg !301 + %7830 = fptrunc float %7787 to bfloat, !dbg !300 + %7831 = select i1 %6717, bfloat %7830, bfloat 0xR0000, !dbg !301 + %7832 = fptrunc float %7788 to bfloat, !dbg !300 + %7833 = select i1 %6718, bfloat %7832, bfloat 0xR0000, !dbg !301 + %7834 = fptrunc float %7789 to bfloat, !dbg !300 + %7835 = select i1 %6719, bfloat %7834, bfloat 0xR0000, !dbg !301 + %7836 = fptrunc float %7790 to bfloat, !dbg !300 + %7837 = select i1 %6720, bfloat %7836, bfloat 0xR0000, !dbg !301 + %7838 = fptrunc float %7791 to bfloat, !dbg !300 + %7839 = select i1 %6721, bfloat %7838, bfloat 0xR0000, !dbg !301 + %7840 = fptrunc float %7792 to bfloat, !dbg !300 + %7841 = select i1 %6722, bfloat %7840, bfloat 0xR0000, !dbg !301 + %7842 = fptrunc float %7793 to bfloat, !dbg !300 + %7843 = select i1 %6723, bfloat %7842, bfloat 0xR0000, !dbg !301 + %7844 = fptrunc float %7794 to bfloat, !dbg !300 + %7845 = select i1 %6724, bfloat %7844, bfloat 0xR0000, !dbg !301 + %7846 = fptrunc float %7795 to bfloat, !dbg !300 + %7847 = select i1 %6725, bfloat %7846, bfloat 0xR0000, !dbg !301 + %7848 = fptrunc float %7796 to bfloat, !dbg !300 + %7849 = select i1 %6726, bfloat %7848, bfloat 0xR0000, !dbg !301 + %7850 = fptrunc float %7797 to bfloat, !dbg !300 + %7851 = select i1 %6727, bfloat %7850, bfloat 0xR0000, !dbg !301 + %7852 = fptrunc float %7798 to bfloat, !dbg !300 + %7853 = select i1 %6728, bfloat %7852, bfloat 0xR0000, !dbg !301 + %7854 = fptrunc float %7799 to bfloat, !dbg !300 + %7855 = select i1 %6729, bfloat %7854, bfloat 0xR0000, !dbg !301 + %7856 = fptrunc float %7800 to bfloat, !dbg !300 + %7857 = select i1 %6730, bfloat %7856, bfloat 0xR0000, !dbg !301 + %7858 = fptrunc float %7801 to bfloat, !dbg !300 + %7859 = select i1 %6731, bfloat %7858, bfloat 0xR0000, !dbg !301 + %7860 = fptrunc float %7802 to bfloat, !dbg !300 + %7861 = select i1 %6732, bfloat %7860, bfloat 0xR0000, !dbg !301 + %7862 = fptrunc float %7803 to bfloat, !dbg !300 + %7863 = select i1 %6733, bfloat %7862, bfloat 0xR0000, !dbg !301 + %7864 = fptrunc float %7804 to bfloat, !dbg !300 + %7865 = select i1 %6734, bfloat %7864, bfloat 0xR0000, !dbg !301 + %7866 = fptrunc float %7805 to bfloat, !dbg !300 + %7867 = select i1 %6735, bfloat %7866, bfloat 0xR0000, !dbg !301 + %7868 = fptrunc float %7806 to bfloat, !dbg !300 + %7869 = select i1 %6736, bfloat %7868, bfloat 0xR0000, !dbg !301 + %7870 = fptrunc float %7807 to bfloat, !dbg !300 + %7871 = select i1 %6737, bfloat %7870, bfloat 0xR0000, !dbg !301 + %7872 = insertelement <2 x bfloat> poison, bfloat %7809, i64 0, !dbg !302 + %7873 = insertelement <2 x bfloat> %7872, bfloat %7811, i64 1, !dbg !302 + %7874 = bitcast <2 x bfloat> %7873 to i32, !dbg !302 + %7875 = insertelement <2 x bfloat> poison, bfloat %7813, i64 0, !dbg !302 + %7876 = insertelement <2 x bfloat> %7875, bfloat %7815, i64 1, !dbg !302 + %7877 = bitcast <2 x bfloat> %7876 to i32, !dbg !302 + %7878 = insertelement <2 x bfloat> poison, bfloat %7817, i64 0, !dbg !302 + %7879 = insertelement <2 x bfloat> %7878, bfloat %7819, i64 1, !dbg !302 + %7880 = bitcast <2 x bfloat> %7879 to i32, !dbg !302 + %7881 = insertelement <2 x bfloat> poison, bfloat %7821, i64 0, !dbg !302 + %7882 = insertelement <2 x bfloat> %7881, bfloat %7823, i64 1, !dbg !302 + %7883 = bitcast <2 x bfloat> %7882 to i32, !dbg !302 + %7884 = insertelement <2 x bfloat> poison, bfloat %7825, i64 0, !dbg !302 + %7885 = insertelement <2 x bfloat> %7884, bfloat %7827, i64 1, !dbg !302 + %7886 = bitcast <2 x bfloat> %7885 to i32, !dbg !302 + %7887 = insertelement <2 x bfloat> poison, bfloat %7829, i64 0, !dbg !302 + %7888 = insertelement <2 x bfloat> %7887, bfloat %7831, i64 1, !dbg !302 + %7889 = bitcast <2 x bfloat> %7888 to i32, !dbg !302 + %7890 = insertelement <2 x bfloat> poison, bfloat %7833, i64 0, !dbg !302 + %7891 = insertelement <2 x bfloat> %7890, bfloat %7835, i64 1, !dbg !302 + %7892 = bitcast <2 x bfloat> %7891 to i32, !dbg !302 + %7893 = insertelement <2 x bfloat> poison, bfloat %7837, i64 0, !dbg !302 + %7894 = insertelement <2 x bfloat> %7893, bfloat %7839, i64 1, !dbg !302 + %7895 = bitcast <2 x bfloat> %7894 to i32, !dbg !302 + %7896 = insertelement <2 x bfloat> poison, bfloat %7841, i64 0, !dbg !302 + %7897 = insertelement <2 x bfloat> %7896, bfloat %7843, i64 1, !dbg !302 + %7898 = bitcast <2 x bfloat> %7897 to i32, !dbg !302 + %7899 = insertelement <2 x bfloat> poison, bfloat %7845, i64 0, !dbg !302 + %7900 = insertelement <2 x bfloat> %7899, bfloat %7847, i64 1, !dbg !302 + %7901 = bitcast <2 x bfloat> %7900 to i32, !dbg !302 + %7902 = insertelement <2 x bfloat> poison, bfloat %7849, i64 0, !dbg !302 + %7903 = insertelement <2 x bfloat> %7902, bfloat %7851, i64 1, !dbg !302 + %7904 = bitcast <2 x bfloat> %7903 to i32, !dbg !302 + %7905 = insertelement <2 x bfloat> poison, bfloat %7853, i64 0, !dbg !302 + %7906 = insertelement <2 x bfloat> %7905, bfloat %7855, i64 1, !dbg !302 + %7907 = bitcast <2 x bfloat> %7906 to i32, !dbg !302 + %7908 = insertelement <2 x bfloat> poison, bfloat %7857, i64 0, !dbg !302 + %7909 = insertelement <2 x bfloat> %7908, bfloat %7859, i64 1, !dbg !302 + %7910 = bitcast <2 x bfloat> %7909 to i32, !dbg !302 + %7911 = insertelement <2 x bfloat> poison, bfloat %7861, i64 0, !dbg !302 + %7912 = insertelement <2 x bfloat> %7911, bfloat %7863, i64 1, !dbg !302 + %7913 = bitcast <2 x bfloat> %7912 to i32, !dbg !302 + %7914 = insertelement <2 x bfloat> poison, bfloat %7865, i64 0, !dbg !302 + %7915 = insertelement <2 x bfloat> %7914, bfloat %7867, i64 1, !dbg !302 + %7916 = bitcast <2 x bfloat> %7915 to i32, !dbg !302 + %7917 = insertelement <2 x bfloat> poison, bfloat %7869, i64 0, !dbg !302 + %7918 = insertelement <2 x bfloat> %7917, bfloat %7871, i64 1, !dbg !302 + %7919 = bitcast <2 x bfloat> %7918 to i32, !dbg !302 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !302 + %7920 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5911, float %5912, float %5913, float %5914, float %5915, float %5916, float %5917, float %5918, float %5919, float %5920, float %5921, float %5922, float %5923, float %5924, float %5925, float %5926, float %5927, float %5928, float %5929, float %5930, float %5931, float %5932, float %5933, float %5934, float %5935, float %5936, float %5937, float %5938, float %5939, float %5940, float %5941, float %5942, float %5943, float %5944, float %5945, float %5946, float %5947, float %5948, float %5949, float %5950, float %5951, float %5952, float %5953, float %5954, float %5955, float %5956, float %5957, float %5958, float %5959, float %5960, float %5961, float %5962, float %5963, float %5964, float %5965, float %5966, float %5967, float %5968, float %5969, float %5970, float %5971, float %5972, float %5973, float %5974, i32 %7874, i32 %7877, i32 %7880, i32 %7883, i64 %6081, i1 true) #3, !dbg !302 + %7921 = add i32 %6077, 2048, !dbg !302 + %7922 = lshr exact i32 %7921, 4, !dbg !302 + %7923 = and i32 %7922, 16383, !dbg !302 + %7924 = zext nneg i32 %7923 to i64, !dbg !302 + %7925 = or disjoint i64 %7924, 4611686293338849280, !dbg !302 + %7926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 0, !dbg !302 + %7927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 1, !dbg !302 + %7928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 2, !dbg !302 + %7929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 3, !dbg !302 + %7930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 4, !dbg !302 + %7931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 5, !dbg !302 + %7932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 6, !dbg !302 + %7933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 7, !dbg !302 + %7934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 8, !dbg !302 + %7935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 9, !dbg !302 + %7936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 10, !dbg !302 + %7937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 11, !dbg !302 + %7938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 12, !dbg !302 + %7939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 13, !dbg !302 + %7940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 14, !dbg !302 + %7941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 15, !dbg !302 + %7942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 16, !dbg !302 + %7943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 17, !dbg !302 + %7944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 18, !dbg !302 + %7945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 19, !dbg !302 + %7946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 20, !dbg !302 + %7947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 21, !dbg !302 + %7948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 22, !dbg !302 + %7949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 23, !dbg !302 + %7950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 24, !dbg !302 + %7951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 25, !dbg !302 + %7952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 26, !dbg !302 + %7953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 27, !dbg !302 + %7954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 28, !dbg !302 + %7955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 29, !dbg !302 + %7956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 30, !dbg !302 + %7957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 31, !dbg !302 + %7958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 32, !dbg !302 + %7959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 33, !dbg !302 + %7960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 34, !dbg !302 + %7961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 35, !dbg !302 + %7962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 36, !dbg !302 + %7963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 37, !dbg !302 + %7964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 38, !dbg !302 + %7965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 39, !dbg !302 + %7966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 40, !dbg !302 + %7967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 41, !dbg !302 + %7968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 42, !dbg !302 + %7969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 43, !dbg !302 + %7970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 44, !dbg !302 + %7971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 45, !dbg !302 + %7972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 46, !dbg !302 + %7973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 47, !dbg !302 + %7974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 48, !dbg !302 + %7975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 49, !dbg !302 + %7976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 50, !dbg !302 + %7977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 51, !dbg !302 + %7978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 52, !dbg !302 + %7979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 53, !dbg !302 + %7980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 54, !dbg !302 + %7981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 55, !dbg !302 + %7982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 56, !dbg !302 + %7983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 57, !dbg !302 + %7984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 58, !dbg !302 + %7985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 59, !dbg !302 + %7986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 60, !dbg !302 + %7987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 61, !dbg !302 + %7988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 62, !dbg !302 + %7989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7920, 63, !dbg !302 + %7990 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7926, float %7927, float %7928, float %7929, float %7930, float %7931, float %7932, float %7933, float %7934, float %7935, float %7936, float %7937, float %7938, float %7939, float %7940, float %7941, float %7942, float %7943, float %7944, float %7945, float %7946, float %7947, float %7948, float %7949, float %7950, float %7951, float %7952, float %7953, float %7954, float %7955, float %7956, float %7957, float %7958, float %7959, float %7960, float %7961, float %7962, float %7963, float %7964, float %7965, float %7966, float %7967, float %7968, float %7969, float %7970, float %7971, float %7972, float %7973, float %7974, float %7975, float %7976, float %7977, float %7978, float %7979, float %7980, float %7981, float %7982, float %7983, float %7984, float %7985, float %7986, float %7987, float %7988, float %7989, i32 %7886, i32 %7889, i32 %7892, i32 %7895, i64 %7925, i1 true) #3, !dbg !302 + %7991 = add i32 %6077, 4096, !dbg !302 + %7992 = lshr exact i32 %7991, 4, !dbg !302 + %7993 = and i32 %7992, 16383, !dbg !302 + %7994 = zext nneg i32 %7993 to i64, !dbg !302 + %7995 = or disjoint i64 %7994, 4611686293338849280, !dbg !302 + %7996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 0, !dbg !302 + %7997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 1, !dbg !302 + %7998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 2, !dbg !302 + %7999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 3, !dbg !302 + %8000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 4, !dbg !302 + %8001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 5, !dbg !302 + %8002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 6, !dbg !302 + %8003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 7, !dbg !302 + %8004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 8, !dbg !302 + %8005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 9, !dbg !302 + %8006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 10, !dbg !302 + %8007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 11, !dbg !302 + %8008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 12, !dbg !302 + %8009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 13, !dbg !302 + %8010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 14, !dbg !302 + %8011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 15, !dbg !302 + %8012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 16, !dbg !302 + %8013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 17, !dbg !302 + %8014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 18, !dbg !302 + %8015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 19, !dbg !302 + %8016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 20, !dbg !302 + %8017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 21, !dbg !302 + %8018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 22, !dbg !302 + %8019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 23, !dbg !302 + %8020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 24, !dbg !302 + %8021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 25, !dbg !302 + %8022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 26, !dbg !302 + %8023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 27, !dbg !302 + %8024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 28, !dbg !302 + %8025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 29, !dbg !302 + %8026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 30, !dbg !302 + %8027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 31, !dbg !302 + %8028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 32, !dbg !302 + %8029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 33, !dbg !302 + %8030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 34, !dbg !302 + %8031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 35, !dbg !302 + %8032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 36, !dbg !302 + %8033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 37, !dbg !302 + %8034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 38, !dbg !302 + %8035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 39, !dbg !302 + %8036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 40, !dbg !302 + %8037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 41, !dbg !302 + %8038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 42, !dbg !302 + %8039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 43, !dbg !302 + %8040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 44, !dbg !302 + %8041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 45, !dbg !302 + %8042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 46, !dbg !302 + %8043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 47, !dbg !302 + %8044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 48, !dbg !302 + %8045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 49, !dbg !302 + %8046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 50, !dbg !302 + %8047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 51, !dbg !302 + %8048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 52, !dbg !302 + %8049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 53, !dbg !302 + %8050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 54, !dbg !302 + %8051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 55, !dbg !302 + %8052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 56, !dbg !302 + %8053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 57, !dbg !302 + %8054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 58, !dbg !302 + %8055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 59, !dbg !302 + %8056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 60, !dbg !302 + %8057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 61, !dbg !302 + %8058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 62, !dbg !302 + %8059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7990, 63, !dbg !302 + %8060 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7996, float %7997, float %7998, float %7999, float %8000, float %8001, float %8002, float %8003, float %8004, float %8005, float %8006, float %8007, float %8008, float %8009, float %8010, float %8011, float %8012, float %8013, float %8014, float %8015, float %8016, float %8017, float %8018, float %8019, float %8020, float %8021, float %8022, float %8023, float %8024, float %8025, float %8026, float %8027, float %8028, float %8029, float %8030, float %8031, float %8032, float %8033, float %8034, float %8035, float %8036, float %8037, float %8038, float %8039, float %8040, float %8041, float %8042, float %8043, float %8044, float %8045, float %8046, float %8047, float %8048, float %8049, float %8050, float %8051, float %8052, float %8053, float %8054, float %8055, float %8056, float %8057, float %8058, float %8059, i32 %7898, i32 %7901, i32 %7904, i32 %7907, i64 %7995, i1 true) #3, !dbg !302 + %8061 = add i32 %6077, 6144, !dbg !302 + %8062 = lshr exact i32 %8061, 4, !dbg !302 + %8063 = and i32 %8062, 16383, !dbg !302 + %8064 = zext nneg i32 %8063 to i64, !dbg !302 + %8065 = or disjoint i64 %8064, 4611686293338849280, !dbg !302 + %8066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 0, !dbg !302 + %8067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 1, !dbg !302 + %8068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 2, !dbg !302 + %8069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 3, !dbg !302 + %8070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 4, !dbg !302 + %8071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 5, !dbg !302 + %8072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 6, !dbg !302 + %8073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 7, !dbg !302 + %8074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 8, !dbg !302 + %8075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 9, !dbg !302 + %8076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 10, !dbg !302 + %8077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 11, !dbg !302 + %8078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 12, !dbg !302 + %8079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 13, !dbg !302 + %8080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 14, !dbg !302 + %8081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 15, !dbg !302 + %8082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 16, !dbg !302 + %8083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 17, !dbg !302 + %8084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 18, !dbg !302 + %8085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 19, !dbg !302 + %8086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 20, !dbg !302 + %8087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 21, !dbg !302 + %8088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 22, !dbg !302 + %8089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 23, !dbg !302 + %8090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 24, !dbg !302 + %8091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 25, !dbg !302 + %8092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 26, !dbg !302 + %8093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 27, !dbg !302 + %8094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 28, !dbg !302 + %8095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 29, !dbg !302 + %8096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 30, !dbg !302 + %8097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 31, !dbg !302 + %8098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 32, !dbg !302 + %8099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 33, !dbg !302 + %8100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 34, !dbg !302 + %8101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 35, !dbg !302 + %8102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 36, !dbg !302 + %8103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 37, !dbg !302 + %8104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 38, !dbg !302 + %8105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 39, !dbg !302 + %8106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 40, !dbg !302 + %8107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 41, !dbg !302 + %8108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 42, !dbg !302 + %8109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 43, !dbg !302 + %8110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 44, !dbg !302 + %8111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 45, !dbg !302 + %8112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 46, !dbg !302 + %8113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 47, !dbg !302 + %8114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 48, !dbg !302 + %8115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 49, !dbg !302 + %8116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 50, !dbg !302 + %8117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 51, !dbg !302 + %8118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 52, !dbg !302 + %8119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 53, !dbg !302 + %8120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 54, !dbg !302 + %8121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 55, !dbg !302 + %8122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 56, !dbg !302 + %8123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 57, !dbg !302 + %8124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 58, !dbg !302 + %8125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 59, !dbg !302 + %8126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 60, !dbg !302 + %8127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 61, !dbg !302 + %8128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 62, !dbg !302 + %8129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8060, 63, !dbg !302 + %8130 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %8066, float %8067, float %8068, float %8069, float %8070, float %8071, float %8072, float %8073, float %8074, float %8075, float %8076, float %8077, float %8078, float %8079, float %8080, float %8081, float %8082, float %8083, float %8084, float %8085, float %8086, float %8087, float %8088, float %8089, float %8090, float %8091, float %8092, float %8093, float %8094, float %8095, float %8096, float %8097, float %8098, float %8099, float %8100, float %8101, float %8102, float %8103, float %8104, float %8105, float %8106, float %8107, float %8108, float %8109, float %8110, float %8111, float %8112, float %8113, float %8114, float %8115, float %8116, float %8117, float %8118, float %8119, float %8120, float %8121, float %8122, float %8123, float %8124, float %8125, float %8126, float %8127, float %8128, float %8129, i32 %7910, i32 %7913, i32 %7916, i32 %7919, i64 %8065, i1 true) #3, !dbg !302 + %8131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 0, !dbg !302 + %8132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 1, !dbg !302 + %8133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 2, !dbg !302 + %8134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 3, !dbg !302 + %8135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 4, !dbg !302 + %8136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 5, !dbg !302 + %8137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 6, !dbg !302 + %8138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 7, !dbg !302 + %8139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 8, !dbg !302 + %8140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 9, !dbg !302 + %8141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 10, !dbg !302 + %8142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 11, !dbg !302 + %8143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 12, !dbg !302 + %8144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 13, !dbg !302 + %8145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 14, !dbg !302 + %8146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 15, !dbg !302 + %8147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 16, !dbg !302 + %8148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 17, !dbg !302 + %8149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 18, !dbg !302 + %8150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 19, !dbg !302 + %8151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 20, !dbg !302 + %8152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 21, !dbg !302 + %8153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 22, !dbg !302 + %8154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 23, !dbg !302 + %8155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 24, !dbg !302 + %8156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 25, !dbg !302 + %8157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 26, !dbg !302 + %8158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 27, !dbg !302 + %8159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 28, !dbg !302 + %8160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 29, !dbg !302 + %8161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 30, !dbg !302 + %8162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 31, !dbg !302 + %8163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 32, !dbg !302 + %8164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 33, !dbg !302 + %8165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 34, !dbg !302 + %8166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 35, !dbg !302 + %8167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 36, !dbg !302 + %8168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 37, !dbg !302 + %8169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 38, !dbg !302 + %8170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 39, !dbg !302 + %8171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 40, !dbg !302 + %8172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 41, !dbg !302 + %8173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 42, !dbg !302 + %8174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 43, !dbg !302 + %8175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 44, !dbg !302 + %8176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 45, !dbg !302 + %8177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 46, !dbg !302 + %8178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 47, !dbg !302 + %8179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 48, !dbg !302 + %8180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 49, !dbg !302 + %8181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 50, !dbg !302 + %8182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 51, !dbg !302 + %8183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 52, !dbg !302 + %8184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 53, !dbg !302 + %8185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 54, !dbg !302 + %8186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 55, !dbg !302 + %8187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 56, !dbg !302 + %8188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 57, !dbg !302 + %8189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 58, !dbg !302 + %8190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 59, !dbg !302 + %8191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 60, !dbg !302 + %8192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 61, !dbg !302 + %8193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 62, !dbg !302 + %8194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8130, 63, !dbg !302 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !302 + %8195 = insertelement <16 x i32> poison, i32 %5818, i64 0, !dbg !303 + %8196 = shufflevector <16 x i32> %8195, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !303 + %8197 = add <16 x i32> %5976, %8196, !dbg !303 + %8198 = add nuw nsw i32 %5975, 1, !dbg !276 + %8199 = lshr i32 %8198, 1, !dbg !304 + %8200 = zext nneg i32 %8199 to i64, !dbg !305 + %8201 = getelementptr i32, ptr addrspace(1) %4991, i64 %8200, !dbg !305 + %8202 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !306 + %8203 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %8201, i64 %8202, i1 %5978) #3, !dbg !306 + %8204 = add nuw nsw i32 %8199, 1, !dbg !307 + %8205 = icmp slt i32 %8204, %4996, !dbg !308 + %8206 = getelementptr i8, ptr addrspace(1) %8201, i64 4, !dbg !309 + %8207 = and i1 %5978, %8205, !dbg !276 + %8208 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !310 + %8209 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %8206, i64 %8208, i1 %8207) #3, !dbg !310 + %8210 = and i32 %5975, 1, !dbg !311 + %8211 = sub i32 %8209, %8203, !dbg !312 + %8212 = shl i32 %8211, 7, !dbg !313 + %8213 = add i32 %8212, -64, !dbg !314 + %8214 = xor i32 %8210, 1, !dbg !315 + %8215 = mul nuw nsw i32 %8213, %8214, !dbg !315 + %8216 = shl nuw nsw i32 %8210, 6, !dbg !316 + %8217 = add i32 %8215, %8216, !dbg !317 + %8218 = shl i32 %8217, 12, !dbg !318 + %8219 = sext i32 %8218 to i64, !dbg !274 + %8220 = getelementptr bfloat, ptr addrspace(1) %.pn1231675, i64 %8219, !dbg !274 + %8221 = getelementptr bfloat, ptr addrspace(1) %.pn1071676, i64 %8219, !dbg !274 + %8222 = getelementptr bfloat, ptr addrspace(1) %.pn911677, i64 %8219, !dbg !274 + %8223 = getelementptr bfloat, ptr addrspace(1) %.pn751678, i64 %8219, !dbg !274 + %8224 = shl i32 %8217, 7, !dbg !319 + %8225 = sext i32 %8224 to i64, !dbg !275 + %8226 = getelementptr bfloat, ptr addrspace(1) %.pn1871679, i64 %8225, !dbg !275 + %8227 = getelementptr bfloat, ptr addrspace(1) %.pn1711680, i64 %8225, !dbg !275 + %8228 = getelementptr bfloat, ptr addrspace(1) %.pn1551681, i64 %8225, !dbg !275 + %8229 = getelementptr bfloat, ptr addrspace(1) %.pn1391682, i64 %8225, !dbg !275 + %8230 = add i32 %8217, %.pn2191683, !dbg !303 + %8231 = add i32 %8217, %.pn2171684, !dbg !303 + %8232 = add i32 %8217, %.pn2151685, !dbg !303 + %8233 = add i32 %8217, %.pn2131686, !dbg !303 + %8234 = add i32 %8217, %.pn2111687, !dbg !303 + %8235 = add i32 %8217, %.pn2091688, !dbg !303 + %8236 = add i32 %8217, %.pn2071689, !dbg !303 + %8237 = add i32 %8217, %.pn2051690, !dbg !303 + %8238 = add i32 %8217, %.pn2031691, !dbg !303 + %8239 = add i32 %8217, %.pn2011692, !dbg !303 + %8240 = add i32 %8217, %.pn1991693, !dbg !303 + %8241 = add i32 %8217, %.pn1971694, !dbg !303 + %8242 = add i32 %8217, %.pn1951695, !dbg !303 + %8243 = add i32 %8217, %.pn1931696, !dbg !303 + %8244 = add i32 %8217, %.pn1911697, !dbg !303 + %8245 = add i32 %8217, %.pn1891698, !dbg !303 + %8246 = add i32 %8217, %5843, !dbg !303 + %8247 = add i32 %8217, %5844, !dbg !303 + %8248 = add i32 %8217, %5845, !dbg !303 + %8249 = add i32 %8217, %5846, !dbg !303 + %8250 = add i32 %8217, %5839, !dbg !303 + %8251 = add i32 %8217, %5840, !dbg !303 + %8252 = add i32 %8217, %5841, !dbg !303 + %8253 = add i32 %8217, %5842, !dbg !303 + %8254 = add i32 %5836, 1, !dbg !276 + %8255 = icmp sgt i32 %8254, 1, !dbg !276 + %8256 = select i1 %8255, i32 0, i32 %8254, !dbg !276 + %8257 = add i32 %5838, 1, !dbg !276 + %8258 = icmp sgt i32 %8257, 2, !dbg !276 + %8259 = select i1 %8258, i32 0, i32 %8257, !dbg !276 + %8260 = icmp slt i32 %8246, %17, !dbg !277 + %8261 = icmp slt i32 %8247, %17, !dbg !277 + %8262 = icmp slt i32 %8248, %17, !dbg !277 + %8263 = icmp slt i32 %8249, %17, !dbg !277 + %8264 = shl i32 %8259, 13, !dbg !268 + %8265 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %8264, !dbg !268 + %8266 = and i1 %5977, %8260, !dbg !276 + %8267 = and i1 %5977, %8261, !dbg !276 + %8268 = and i1 %5977, %8262, !dbg !276 + %8269 = and i1 %5977, %8263, !dbg !276 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !268 + %8270 = getelementptr inbounds nuw i8, ptr addrspace(3) %8265, i32 %5111, !dbg !268 + %8271 = select i1 %8266, i32 16, i32 0, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %8270, ptr addrspace(1) %8220, i32 %8271) #3, !dbg !268 + %8272 = getelementptr inbounds nuw i8, ptr addrspace(3) %8265, i32 %5114, !dbg !268 + %8273 = select i1 %8267, i32 16, i32 0, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8272, ptr addrspace(1) %8221, i32 %8273) #3, !dbg !268 + %8274 = getelementptr inbounds nuw i8, ptr addrspace(3) %8265, i32 %5117, !dbg !268 + %8275 = select i1 %8268, i32 16, i32 0, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8274, ptr addrspace(1) %8222, i32 %8275) #3, !dbg !268 + %8276 = getelementptr inbounds nuw i8, ptr addrspace(3) %8265, i32 %5120, !dbg !268 + %8277 = select i1 %8269, i32 16, i32 0, !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8276, ptr addrspace(1) %8223, i32 %8277) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + %8278 = icmp slt i32 %8230, %17, !dbg !320 + %8279 = icmp slt i32 %8231, %17, !dbg !320 + %8280 = icmp slt i32 %8232, %17, !dbg !320 + %8281 = icmp slt i32 %8233, %17, !dbg !320 + %8282 = icmp slt i32 %8234, %17, !dbg !320 + %8283 = icmp slt i32 %8235, %17, !dbg !320 + %8284 = icmp slt i32 %8236, %17, !dbg !320 + %8285 = icmp slt i32 %8237, %17, !dbg !320 + %8286 = icmp slt i32 %8238, %17, !dbg !320 + %8287 = icmp slt i32 %8239, %17, !dbg !320 + %8288 = icmp slt i32 %8240, %17, !dbg !320 + %8289 = icmp slt i32 %8241, %17, !dbg !320 + %8290 = icmp slt i32 %8242, %17, !dbg !320 + %8291 = icmp slt i32 %8243, %17, !dbg !320 + %8292 = icmp slt i32 %8244, %17, !dbg !320 + %8293 = icmp slt i32 %8245, %17, !dbg !320 + %8294 = sext i32 %8230 to i64, !dbg !269 + %8295 = getelementptr float, ptr addrspace(1) %5728, i64 %8294, !dbg !269 + %8296 = sext i32 %8231 to i64, !dbg !269 + %8297 = getelementptr float, ptr addrspace(1) %5728, i64 %8296, !dbg !269 + %8298 = sext i32 %8232 to i64, !dbg !269 + %8299 = getelementptr float, ptr addrspace(1) %5728, i64 %8298, !dbg !269 + %8300 = sext i32 %8233 to i64, !dbg !269 + %8301 = getelementptr float, ptr addrspace(1) %5728, i64 %8300, !dbg !269 + %8302 = sext i32 %8234 to i64, !dbg !269 + %8303 = getelementptr float, ptr addrspace(1) %5728, i64 %8302, !dbg !269 + %8304 = sext i32 %8235 to i64, !dbg !269 + %8305 = getelementptr float, ptr addrspace(1) %5728, i64 %8304, !dbg !269 + %8306 = sext i32 %8236 to i64, !dbg !269 + %8307 = getelementptr float, ptr addrspace(1) %5728, i64 %8306, !dbg !269 + %8308 = sext i32 %8237 to i64, !dbg !269 + %8309 = getelementptr float, ptr addrspace(1) %5728, i64 %8308, !dbg !269 + %8310 = sext i32 %8238 to i64, !dbg !269 + %8311 = getelementptr float, ptr addrspace(1) %5728, i64 %8310, !dbg !269 + %8312 = sext i32 %8239 to i64, !dbg !269 + %8313 = getelementptr float, ptr addrspace(1) %5728, i64 %8312, !dbg !269 + %8314 = sext i32 %8240 to i64, !dbg !269 + %8315 = getelementptr float, ptr addrspace(1) %5728, i64 %8314, !dbg !269 + %8316 = sext i32 %8241 to i64, !dbg !269 + %8317 = getelementptr float, ptr addrspace(1) %5728, i64 %8316, !dbg !269 + %8318 = sext i32 %8242 to i64, !dbg !269 + %8319 = getelementptr float, ptr addrspace(1) %5728, i64 %8318, !dbg !269 + %8320 = sext i32 %8243 to i64, !dbg !269 + %8321 = getelementptr float, ptr addrspace(1) %5728, i64 %8320, !dbg !269 + %8322 = sext i32 %8244 to i64, !dbg !269 + %8323 = getelementptr float, ptr addrspace(1) %5728, i64 %8322, !dbg !269 + %8324 = sext i32 %8245 to i64, !dbg !269 + %8325 = getelementptr float, ptr addrspace(1) %5728, i64 %8324, !dbg !269 + %8326 = shl i32 %8256, 6, !dbg !270 + %8327 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %8326, !dbg !270 + %8328 = and i1 %5977, %8278, !dbg !276 + %8329 = and i1 %5977, %8279, !dbg !276 + %8330 = and i1 %5977, %8280, !dbg !276 + %8331 = and i1 %5977, %8281, !dbg !276 + %8332 = and i1 %5977, %8282, !dbg !276 + %8333 = and i1 %5977, %8283, !dbg !276 + %8334 = and i1 %5977, %8284, !dbg !276 + %8335 = and i1 %5977, %8285, !dbg !276 + %8336 = and i1 %5977, %8286, !dbg !276 + %8337 = and i1 %5977, %8287, !dbg !276 + %8338 = and i1 %5977, %8288, !dbg !276 + %8339 = and i1 %5977, %8289, !dbg !276 + %8340 = and i1 %5977, %8290, !dbg !276 + %8341 = and i1 %5977, %8291, !dbg !276 + %8342 = and i1 %5977, %8292, !dbg !276 + %8343 = and i1 %5977, %8293, !dbg !276 + %8344 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5189, !dbg !270 + %8345 = select i1 %8328, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %8344, ptr addrspace(1) %8295, i32 %8345, i1 %5188) #3, !dbg !270 + %8346 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5192, !dbg !270 + %8347 = select i1 %8329, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8346, ptr addrspace(1) %8297, i32 %8347, i1 %5188) #3, !dbg !270 + %8348 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5195, !dbg !270 + %8349 = select i1 %8330, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8348, ptr addrspace(1) %8299, i32 %8349, i1 %5188) #3, !dbg !270 + %8350 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5198, !dbg !270 + %8351 = select i1 %8331, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8350, ptr addrspace(1) %8301, i32 %8351, i1 %5188) #3, !dbg !270 + %8352 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5201, !dbg !270 + %8353 = select i1 %8332, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8352, ptr addrspace(1) %8303, i32 %8353, i1 %5188) #3, !dbg !270 + %8354 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5204, !dbg !270 + %8355 = select i1 %8333, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8354, ptr addrspace(1) %8305, i32 %8355, i1 %5188) #3, !dbg !270 + %8356 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5207, !dbg !270 + %8357 = select i1 %8334, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8356, ptr addrspace(1) %8307, i32 %8357, i1 %5188) #3, !dbg !270 + %8358 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5210, !dbg !270 + %8359 = select i1 %8335, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8358, ptr addrspace(1) %8309, i32 %8359, i1 %5188) #3, !dbg !270 + %8360 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5213, !dbg !270 + %8361 = select i1 %8336, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8360, ptr addrspace(1) %8311, i32 %8361, i1 %5188) #3, !dbg !270 + %8362 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5216, !dbg !270 + %8363 = select i1 %8337, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8362, ptr addrspace(1) %8313, i32 %8363, i1 %5188) #3, !dbg !270 + %8364 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5219, !dbg !270 + %8365 = select i1 %8338, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8364, ptr addrspace(1) %8315, i32 %8365, i1 %5188) #3, !dbg !270 + %8366 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5222, !dbg !270 + %8367 = select i1 %8339, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8366, ptr addrspace(1) %8317, i32 %8367, i1 %5188) #3, !dbg !270 + %8368 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5225, !dbg !270 + %8369 = select i1 %8340, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8368, ptr addrspace(1) %8319, i32 %8369, i1 %5188) #3, !dbg !270 + %8370 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5228, !dbg !270 + %8371 = select i1 %8341, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8370, ptr addrspace(1) %8321, i32 %8371, i1 %5188) #3, !dbg !270 + %8372 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5231, !dbg !270 + %8373 = select i1 %8342, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8372, ptr addrspace(1) %8323, i32 %8373, i1 %5188) #3, !dbg !270 + %8374 = getelementptr inbounds nuw i8, ptr addrspace(3) %8327, i32 %5234, !dbg !270 + %8375 = select i1 %8343, i32 4, i32 0, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8374, ptr addrspace(1) %8325, i32 %8375, i1 %5188) #3, !dbg !270 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !270 + %8376 = icmp slt i32 %8250, %17, !dbg !321 + %8377 = icmp slt i32 %8251, %17, !dbg !321 + %8378 = icmp slt i32 %8252, %17, !dbg !321 + %8379 = icmp slt i32 %8253, %17, !dbg !321 + %8380 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %8264, !dbg !271 + %8381 = and i1 %5977, %8376, !dbg !276 + %8382 = and i1 %5977, %8377, !dbg !276 + %8383 = and i1 %5977, %8378, !dbg !276 + %8384 = and i1 %5977, %8379, !dbg !276 + %8385 = getelementptr inbounds nuw i8, ptr addrspace(3) %8380, i32 %5111, !dbg !271 + %8386 = select i1 %8381, i32 16, i32 0, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %8385, ptr addrspace(1) %8226, i32 %8386) #3, !dbg !271 + %8387 = getelementptr inbounds nuw i8, ptr addrspace(3) %8380, i32 %5114, !dbg !271 + %8388 = select i1 %8382, i32 16, i32 0, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8387, ptr addrspace(1) %8227, i32 %8388) #3, !dbg !271 + %8389 = getelementptr inbounds nuw i8, ptr addrspace(3) %8380, i32 %5117, !dbg !271 + %8390 = select i1 %8383, i32 16, i32 0, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8389, ptr addrspace(1) %8228, i32 %8390) #3, !dbg !271 + %8391 = getelementptr inbounds nuw i8, ptr addrspace(3) %8380, i32 %5120, !dbg !271 + %8392 = select i1 %8384, i32 16, i32 0, !dbg !271 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8391, ptr addrspace(1) %8229, i32 %8392) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + %8393 = getelementptr float, ptr addrspace(1) %5729, i64 %8294, !dbg !272 + %8394 = getelementptr float, ptr addrspace(1) %5729, i64 %8296, !dbg !272 + %8395 = getelementptr float, ptr addrspace(1) %5729, i64 %8298, !dbg !272 + %8396 = getelementptr float, ptr addrspace(1) %5729, i64 %8300, !dbg !272 + %8397 = getelementptr float, ptr addrspace(1) %5729, i64 %8302, !dbg !272 + %8398 = getelementptr float, ptr addrspace(1) %5729, i64 %8304, !dbg !272 + %8399 = getelementptr float, ptr addrspace(1) %5729, i64 %8306, !dbg !272 + %8400 = getelementptr float, ptr addrspace(1) %5729, i64 %8308, !dbg !272 + %8401 = getelementptr float, ptr addrspace(1) %5729, i64 %8310, !dbg !272 + %8402 = getelementptr float, ptr addrspace(1) %5729, i64 %8312, !dbg !272 + %8403 = getelementptr float, ptr addrspace(1) %5729, i64 %8314, !dbg !272 + %8404 = getelementptr float, ptr addrspace(1) %5729, i64 %8316, !dbg !272 + %8405 = getelementptr float, ptr addrspace(1) %5729, i64 %8318, !dbg !272 + %8406 = getelementptr float, ptr addrspace(1) %5729, i64 %8320, !dbg !272 + %8407 = getelementptr float, ptr addrspace(1) %5729, i64 %8322, !dbg !272 + %8408 = getelementptr float, ptr addrspace(1) %5729, i64 %8324, !dbg !272 + %8409 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %8326, !dbg !273 + %8410 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5189, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %8410, ptr addrspace(1) %8393, i32 %8345, i1 %5188) #3, !dbg !273 + %8411 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5192, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8411, ptr addrspace(1) %8394, i32 %8347, i1 %5188) #3, !dbg !273 + %8412 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5195, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8412, ptr addrspace(1) %8395, i32 %8349, i1 %5188) #3, !dbg !273 + %8413 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5198, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8413, ptr addrspace(1) %8396, i32 %8351, i1 %5188) #3, !dbg !273 + %8414 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5201, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8414, ptr addrspace(1) %8397, i32 %8353, i1 %5188) #3, !dbg !273 + %8415 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5204, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8415, ptr addrspace(1) %8398, i32 %8355, i1 %5188) #3, !dbg !273 + %8416 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5207, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8416, ptr addrspace(1) %8399, i32 %8357, i1 %5188) #3, !dbg !273 + %8417 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5210, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8417, ptr addrspace(1) %8400, i32 %8359, i1 %5188) #3, !dbg !273 + %8418 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5213, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8418, ptr addrspace(1) %8401, i32 %8361, i1 %5188) #3, !dbg !273 + %8419 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5216, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8419, ptr addrspace(1) %8402, i32 %8363, i1 %5188) #3, !dbg !273 + %8420 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5219, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8420, ptr addrspace(1) %8403, i32 %8365, i1 %5188) #3, !dbg !273 + %8421 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5222, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8421, ptr addrspace(1) %8404, i32 %8367, i1 %5188) #3, !dbg !273 + %8422 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5225, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8422, ptr addrspace(1) %8405, i32 %8369, i1 %5188) #3, !dbg !273 + %8423 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5228, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8423, ptr addrspace(1) %8406, i32 %8371, i1 %5188) #3, !dbg !273 + %8424 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5231, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8424, ptr addrspace(1) %8407, i32 %8373, i1 %5188) #3, !dbg !273 + %8425 = getelementptr inbounds nuw i8, ptr addrspace(3) %8409, i32 %5234, !dbg !273 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8425, ptr addrspace(1) %8408, i32 %8375, i1 %5188) #3, !dbg !273 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !273 + %exitcond2266.not = icmp eq i32 %8198, %smax2265, !dbg !276 + br i1 %exitcond2266.not, label %._crit_edge1701, label %.lr.ph1700, !dbg !276 + +._crit_edge1701: ; preds = %__nv_exp2f.exit1417, %5585 + %8426 = phi float [ %5586, %5585 ], [ %8131, %__nv_exp2f.exit1417 ] + %8427 = phi float [ %5587, %5585 ], [ %8132, %__nv_exp2f.exit1417 ] + %8428 = phi float [ %5588, %5585 ], [ %8133, %__nv_exp2f.exit1417 ] + %8429 = phi float [ %5589, %5585 ], [ %8134, %__nv_exp2f.exit1417 ] + %8430 = phi float [ %5590, %5585 ], [ %8135, %__nv_exp2f.exit1417 ] + %8431 = phi float [ %5591, %5585 ], [ %8136, %__nv_exp2f.exit1417 ] + %8432 = phi float [ %5592, %5585 ], [ %8137, %__nv_exp2f.exit1417 ] + %8433 = phi float [ %5593, %5585 ], [ %8138, %__nv_exp2f.exit1417 ] + %8434 = phi float [ %5594, %5585 ], [ %8139, %__nv_exp2f.exit1417 ] + %8435 = phi float [ %5595, %5585 ], [ %8140, %__nv_exp2f.exit1417 ] + %8436 = phi float [ %5596, %5585 ], [ %8141, %__nv_exp2f.exit1417 ] + %8437 = phi float [ %5597, %5585 ], [ %8142, %__nv_exp2f.exit1417 ] + %8438 = phi float [ %5598, %5585 ], [ %8143, %__nv_exp2f.exit1417 ] + %8439 = phi float [ %5599, %5585 ], [ %8144, %__nv_exp2f.exit1417 ] + %8440 = phi float [ %5600, %5585 ], [ %8145, %__nv_exp2f.exit1417 ] + %8441 = phi float [ %5601, %5585 ], [ %8146, %__nv_exp2f.exit1417 ] + %8442 = phi float [ %5602, %5585 ], [ %8147, %__nv_exp2f.exit1417 ] + %8443 = phi float [ %5603, %5585 ], [ %8148, %__nv_exp2f.exit1417 ] + %8444 = phi float [ %5604, %5585 ], [ %8149, %__nv_exp2f.exit1417 ] + %8445 = phi float [ %5605, %5585 ], [ %8150, %__nv_exp2f.exit1417 ] + %8446 = phi float [ %5606, %5585 ], [ %8151, %__nv_exp2f.exit1417 ] + %8447 = phi float [ %5607, %5585 ], [ %8152, %__nv_exp2f.exit1417 ] + %8448 = phi float [ %5608, %5585 ], [ %8153, %__nv_exp2f.exit1417 ] + %8449 = phi float [ %5609, %5585 ], [ %8154, %__nv_exp2f.exit1417 ] + %8450 = phi float [ %5610, %5585 ], [ %8155, %__nv_exp2f.exit1417 ] + %8451 = phi float [ %5611, %5585 ], [ %8156, %__nv_exp2f.exit1417 ] + %8452 = phi float [ %5612, %5585 ], [ %8157, %__nv_exp2f.exit1417 ] + %8453 = phi float [ %5613, %5585 ], [ %8158, %__nv_exp2f.exit1417 ] + %8454 = phi float [ %5614, %5585 ], [ %8159, %__nv_exp2f.exit1417 ] + %8455 = phi float [ %5615, %5585 ], [ %8160, %__nv_exp2f.exit1417 ] + %8456 = phi float [ %5616, %5585 ], [ %8161, %__nv_exp2f.exit1417 ] + %8457 = phi float [ %5617, %5585 ], [ %8162, %__nv_exp2f.exit1417 ] + %8458 = phi float [ %5618, %5585 ], [ %8163, %__nv_exp2f.exit1417 ] + %8459 = phi float [ %5619, %5585 ], [ %8164, %__nv_exp2f.exit1417 ] + %8460 = phi float [ %5620, %5585 ], [ %8165, %__nv_exp2f.exit1417 ] + %8461 = phi float [ %5621, %5585 ], [ %8166, %__nv_exp2f.exit1417 ] + %8462 = phi float [ %5622, %5585 ], [ %8167, %__nv_exp2f.exit1417 ] + %8463 = phi float [ %5623, %5585 ], [ %8168, %__nv_exp2f.exit1417 ] + %8464 = phi float [ %5624, %5585 ], [ %8169, %__nv_exp2f.exit1417 ] + %8465 = phi float [ %5625, %5585 ], [ %8170, %__nv_exp2f.exit1417 ] + %8466 = phi float [ %5626, %5585 ], [ %8171, %__nv_exp2f.exit1417 ] + %8467 = phi float [ %5627, %5585 ], [ %8172, %__nv_exp2f.exit1417 ] + %8468 = phi float [ %5628, %5585 ], [ %8173, %__nv_exp2f.exit1417 ] + %8469 = phi float [ %5629, %5585 ], [ %8174, %__nv_exp2f.exit1417 ] + %8470 = phi float [ %5630, %5585 ], [ %8175, %__nv_exp2f.exit1417 ] + %8471 = phi float [ %5631, %5585 ], [ %8176, %__nv_exp2f.exit1417 ] + %8472 = phi float [ %5632, %5585 ], [ %8177, %__nv_exp2f.exit1417 ] + %8473 = phi float [ %5633, %5585 ], [ %8178, %__nv_exp2f.exit1417 ] + %8474 = phi float [ %5634, %5585 ], [ %8179, %__nv_exp2f.exit1417 ] + %8475 = phi float [ %5635, %5585 ], [ %8180, %__nv_exp2f.exit1417 ] + %8476 = phi float [ %5636, %5585 ], [ %8181, %__nv_exp2f.exit1417 ] + %8477 = phi float [ %5637, %5585 ], [ %8182, %__nv_exp2f.exit1417 ] + %8478 = phi float [ %5638, %5585 ], [ %8183, %__nv_exp2f.exit1417 ] + %8479 = phi float [ %5639, %5585 ], [ %8184, %__nv_exp2f.exit1417 ] + %8480 = phi float [ %5640, %5585 ], [ %8185, %__nv_exp2f.exit1417 ] + %8481 = phi float [ %5641, %5585 ], [ %8186, %__nv_exp2f.exit1417 ] + %8482 = phi float [ %5642, %5585 ], [ %8187, %__nv_exp2f.exit1417 ] + %8483 = phi float [ %5643, %5585 ], [ %8188, %__nv_exp2f.exit1417 ] + %8484 = phi float [ %5644, %5585 ], [ %8189, %__nv_exp2f.exit1417 ] + %8485 = phi float [ %5645, %5585 ], [ %8190, %__nv_exp2f.exit1417 ] + %8486 = phi float [ %5646, %5585 ], [ %8191, %__nv_exp2f.exit1417 ] + %8487 = phi float [ %5647, %5585 ], [ %8192, %__nv_exp2f.exit1417 ] + %8488 = phi float [ %5648, %5585 ], [ %8193, %__nv_exp2f.exit1417 ] + %8489 = phi float [ %5649, %5585 ], [ %8194, %__nv_exp2f.exit1417 ] + %8490 = phi float [ %5650, %5585 ], [ %7275, %__nv_exp2f.exit1417 ] + %8491 = phi float [ %5651, %5585 ], [ %7276, %__nv_exp2f.exit1417 ] + %8492 = phi float [ %5652, %5585 ], [ %7277, %__nv_exp2f.exit1417 ] + %8493 = phi float [ %5653, %5585 ], [ %7278, %__nv_exp2f.exit1417 ] + %8494 = phi float [ %5654, %5585 ], [ %7279, %__nv_exp2f.exit1417 ] + %8495 = phi float [ %5655, %5585 ], [ %7280, %__nv_exp2f.exit1417 ] + %8496 = phi float [ %5656, %5585 ], [ %7281, %__nv_exp2f.exit1417 ] + %8497 = phi float [ %5657, %5585 ], [ %7282, %__nv_exp2f.exit1417 ] + %8498 = phi float [ %5658, %5585 ], [ %7283, %__nv_exp2f.exit1417 ] + %8499 = phi float [ %5659, %5585 ], [ %7284, %__nv_exp2f.exit1417 ] + %8500 = phi float [ %5660, %5585 ], [ %7285, %__nv_exp2f.exit1417 ] + %8501 = phi float [ %5661, %5585 ], [ %7286, %__nv_exp2f.exit1417 ] + %8502 = phi float [ %5662, %5585 ], [ %7287, %__nv_exp2f.exit1417 ] + %8503 = phi float [ %5663, %5585 ], [ %7288, %__nv_exp2f.exit1417 ] + %8504 = phi float [ %5664, %5585 ], [ %7289, %__nv_exp2f.exit1417 ] + %8505 = phi float [ %5665, %5585 ], [ %7290, %__nv_exp2f.exit1417 ] + %8506 = phi float [ %5666, %5585 ], [ %7291, %__nv_exp2f.exit1417 ] + %8507 = phi float [ %5667, %5585 ], [ %7292, %__nv_exp2f.exit1417 ] + %8508 = phi float [ %5668, %5585 ], [ %7293, %__nv_exp2f.exit1417 ] + %8509 = phi float [ %5669, %5585 ], [ %7294, %__nv_exp2f.exit1417 ] + %8510 = phi float [ %5670, %5585 ], [ %7295, %__nv_exp2f.exit1417 ] + %8511 = phi float [ %5671, %5585 ], [ %7296, %__nv_exp2f.exit1417 ] + %8512 = phi float [ %5672, %5585 ], [ %7297, %__nv_exp2f.exit1417 ] + %8513 = phi float [ %5673, %5585 ], [ %7298, %__nv_exp2f.exit1417 ] + %8514 = phi float [ %5674, %5585 ], [ %7299, %__nv_exp2f.exit1417 ] + %8515 = phi float [ %5675, %5585 ], [ %7300, %__nv_exp2f.exit1417 ] + %8516 = phi float [ %5676, %5585 ], [ %7301, %__nv_exp2f.exit1417 ] + %8517 = phi float [ %5677, %5585 ], [ %7302, %__nv_exp2f.exit1417 ] + %8518 = phi float [ %5678, %5585 ], [ %7303, %__nv_exp2f.exit1417 ] + %8519 = phi float [ %5679, %5585 ], [ %7304, %__nv_exp2f.exit1417 ] + %8520 = phi float [ %5680, %5585 ], [ %7305, %__nv_exp2f.exit1417 ] + %8521 = phi float [ %5681, %5585 ], [ %7306, %__nv_exp2f.exit1417 ] + %8522 = phi float [ %5682, %5585 ], [ %7307, %__nv_exp2f.exit1417 ] + %8523 = phi float [ %5683, %5585 ], [ %7308, %__nv_exp2f.exit1417 ] + %8524 = phi float [ %5684, %5585 ], [ %7309, %__nv_exp2f.exit1417 ] + %8525 = phi float [ %5685, %5585 ], [ %7310, %__nv_exp2f.exit1417 ] + %8526 = phi float [ %5686, %5585 ], [ %7311, %__nv_exp2f.exit1417 ] + %8527 = phi float [ %5687, %5585 ], [ %7312, %__nv_exp2f.exit1417 ] + %8528 = phi float [ %5688, %5585 ], [ %7313, %__nv_exp2f.exit1417 ] + %8529 = phi float [ %5689, %5585 ], [ %7314, %__nv_exp2f.exit1417 ] + %8530 = phi float [ %5690, %5585 ], [ %7315, %__nv_exp2f.exit1417 ] + %8531 = phi float [ %5691, %5585 ], [ %7316, %__nv_exp2f.exit1417 ] + %8532 = phi float [ %5692, %5585 ], [ %7317, %__nv_exp2f.exit1417 ] + %8533 = phi float [ %5693, %5585 ], [ %7318, %__nv_exp2f.exit1417 ] + %8534 = phi float [ %5694, %5585 ], [ %7319, %__nv_exp2f.exit1417 ] + %8535 = phi float [ %5695, %5585 ], [ %7320, %__nv_exp2f.exit1417 ] + %8536 = phi float [ %5696, %5585 ], [ %7321, %__nv_exp2f.exit1417 ] + %8537 = phi float [ %5697, %5585 ], [ %7322, %__nv_exp2f.exit1417 ] + %8538 = phi float [ %5698, %5585 ], [ %7323, %__nv_exp2f.exit1417 ] + %8539 = phi float [ %5699, %5585 ], [ %7324, %__nv_exp2f.exit1417 ] + %8540 = phi float [ %5700, %5585 ], [ %7325, %__nv_exp2f.exit1417 ] + %8541 = phi float [ %5701, %5585 ], [ %7326, %__nv_exp2f.exit1417 ] + %8542 = phi float [ %5702, %5585 ], [ %7327, %__nv_exp2f.exit1417 ] + %8543 = phi float [ %5703, %5585 ], [ %7328, %__nv_exp2f.exit1417 ] + %8544 = phi float [ %5704, %5585 ], [ %7329, %__nv_exp2f.exit1417 ] + %8545 = phi float [ %5705, %5585 ], [ %7330, %__nv_exp2f.exit1417 ] + %8546 = phi float [ %5706, %5585 ], [ %7331, %__nv_exp2f.exit1417 ] + %8547 = phi float [ %5707, %5585 ], [ %7332, %__nv_exp2f.exit1417 ] + %8548 = phi float [ %5708, %5585 ], [ %7333, %__nv_exp2f.exit1417 ] + %8549 = phi float [ %5709, %5585 ], [ %7334, %__nv_exp2f.exit1417 ] + %8550 = phi float [ %5710, %5585 ], [ %7335, %__nv_exp2f.exit1417 ] + %8551 = phi float [ %5711, %5585 ], [ %7336, %__nv_exp2f.exit1417 ] + %8552 = phi float [ %5712, %5585 ], [ %7337, %__nv_exp2f.exit1417 ] + %8553 = phi float [ %5713, %5585 ], [ %7338, %__nv_exp2f.exit1417 ] + %8554 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %8490, float %8491, float %8492, float %8493, float %8494, float %8495, float %8496, float %8497, float %8498, float %8499, float %8500, float %8501, float %8502, float %8503, float %8504, float %8505, float %8506, float %8507, float %8508, float %8509, float %8510, float %8511, float %8512, float %8513, float %8514, float %8515, float %8516, float %8517, float %8518, float %8519, float %8520, float %8521, float %8522, float %8523, float %8524, float %8525, float %8526, float %8527, float %8528, float %8529, float %8530, float %8531, float %8532, float %8533, float %8534, float %8535, float %8536, float %8537, float %8538, float %8539, float %8540, float %8541, float %8542, float %8543, float %8544, float %8545, float %8546, float %8547, float %8548, float %8549, float %8550, float %8551, float %8552, float %8553, float %8426, float %8427, float %8428, float %8429, float %8430, float %8431, float %8432, float %8433, float %8434, float %8435, float %8436, float %8437, float %8438, float %8439, float %8440, float %8441, float %8442, float %8443, float %8444, float %8445, float %8446, float %8447, float %8448, float %8449, float %8450, float %8451, float %8452, float %8453, float %8454, float %8455, float %8456, float %8457, float %8458, float %8459, float %8460, float %8461, float %8462, float %8463, float %8464, float %8465, float %8466, float %8467, float %8468, float %8469, float %8470, float %8471, float %8472, float %8473, float %8474, float %8475, float %8476, float %8477, float %8478, float %8479, float %8480, float %8481, float %8482, float %8483, float %8484, float %8485, float %8486, float %8487, float %8488, float %8489) #3, !dbg !276 + %8555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 0, !dbg !276 + %8556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 1, !dbg !276 + %8557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 2, !dbg !276 + %8558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 3, !dbg !276 + %8559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 4, !dbg !276 + %8560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 5, !dbg !276 + %8561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 6, !dbg !276 + %8562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 7, !dbg !276 + %8563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 8, !dbg !276 + %8564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 9, !dbg !276 + %8565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 10, !dbg !276 + %8566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 11, !dbg !276 + %8567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 12, !dbg !276 + %8568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 13, !dbg !276 + %8569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 14, !dbg !276 + %8570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 15, !dbg !276 + %8571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 16, !dbg !276 + %8572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 17, !dbg !276 + %8573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 18, !dbg !276 + %8574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 19, !dbg !276 + %8575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 20, !dbg !276 + %8576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 21, !dbg !276 + %8577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 22, !dbg !276 + %8578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 23, !dbg !276 + %8579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 24, !dbg !276 + %8580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 25, !dbg !276 + %8581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 26, !dbg !276 + %8582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 27, !dbg !276 + %8583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 28, !dbg !276 + %8584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 29, !dbg !276 + %8585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 30, !dbg !276 + %8586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 31, !dbg !276 + %8587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 32, !dbg !276 + %8588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 33, !dbg !276 + %8589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 34, !dbg !276 + %8590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 35, !dbg !276 + %8591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 36, !dbg !276 + %8592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 37, !dbg !276 + %8593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 38, !dbg !276 + %8594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 39, !dbg !276 + %8595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 40, !dbg !276 + %8596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 41, !dbg !276 + %8597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 42, !dbg !276 + %8598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 43, !dbg !276 + %8599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 44, !dbg !276 + %8600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 45, !dbg !276 + %8601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 46, !dbg !276 + %8602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 47, !dbg !276 + %8603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 48, !dbg !276 + %8604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 49, !dbg !276 + %8605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 50, !dbg !276 + %8606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 51, !dbg !276 + %8607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 52, !dbg !276 + %8608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 53, !dbg !276 + %8609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 54, !dbg !276 + %8610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 55, !dbg !276 + %8611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 56, !dbg !276 + %8612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 57, !dbg !276 + %8613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 58, !dbg !276 + %8614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 59, !dbg !276 + %8615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 60, !dbg !276 + %8616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 61, !dbg !276 + %8617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 62, !dbg !276 + %8618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 63, !dbg !276 + %8619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 64, !dbg !276 + %8620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 65, !dbg !276 + %8621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 66, !dbg !276 + %8622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 67, !dbg !276 + %8623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 68, !dbg !276 + %8624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 69, !dbg !276 + %8625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 70, !dbg !276 + %8626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 71, !dbg !276 + %8627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 72, !dbg !276 + %8628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 73, !dbg !276 + %8629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 74, !dbg !276 + %8630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 75, !dbg !276 + %8631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 76, !dbg !276 + %8632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 77, !dbg !276 + %8633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 78, !dbg !276 + %8634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 79, !dbg !276 + %8635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 80, !dbg !276 + %8636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 81, !dbg !276 + %8637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 82, !dbg !276 + %8638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 83, !dbg !276 + %8639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 84, !dbg !276 + %8640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 85, !dbg !276 + %8641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 86, !dbg !276 + %8642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 87, !dbg !276 + %8643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 88, !dbg !276 + %8644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 89, !dbg !276 + %8645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 90, !dbg !276 + %8646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 91, !dbg !276 + %8647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 92, !dbg !276 + %8648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 93, !dbg !276 + %8649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 94, !dbg !276 + %8650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 95, !dbg !276 + %8651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 96, !dbg !276 + %8652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 97, !dbg !276 + %8653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 98, !dbg !276 + %8654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 99, !dbg !276 + %8655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 100, !dbg !276 + %8656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 101, !dbg !276 + %8657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 102, !dbg !276 + %8658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 103, !dbg !276 + %8659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 104, !dbg !276 + %8660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 105, !dbg !276 + %8661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 106, !dbg !276 + %8662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 107, !dbg !276 + %8663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 108, !dbg !276 + %8664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 109, !dbg !276 + %8665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 110, !dbg !276 + %8666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 111, !dbg !276 + %8667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 112, !dbg !276 + %8668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 113, !dbg !276 + %8669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 114, !dbg !276 + %8670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 115, !dbg !276 + %8671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 116, !dbg !276 + %8672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 117, !dbg !276 + %8673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 118, !dbg !276 + %8674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 119, !dbg !276 + %8675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 120, !dbg !276 + %8676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 121, !dbg !276 + %8677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 122, !dbg !276 + %8678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 123, !dbg !276 + %8679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 124, !dbg !276 + %8680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 125, !dbg !276 + %8681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 126, !dbg !276 + %8682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8554, 127, !dbg !276 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !276 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !276 + %8683 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5396, !dbg !322 + %8684 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5397, !dbg !322 + %8685 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5398, !dbg !322 + %8686 = getelementptr bfloat, ptr addrspace(1) %5726, i64 %5399, !dbg !322 + %8687 = getelementptr bfloat, ptr addrspace(1) %8683, i64 %4776, !dbg !323 + %8688 = getelementptr bfloat, ptr addrspace(1) %8684, i64 %4776, !dbg !323 + %8689 = getelementptr bfloat, ptr addrspace(1) %8685, i64 %4776, !dbg !323 + %8690 = getelementptr bfloat, ptr addrspace(1) %8686, i64 %4776, !dbg !323 + %8691 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5400, !dbg !324 + %8692 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5401, !dbg !324 + %8693 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5402, !dbg !324 + %8694 = getelementptr bfloat, ptr addrspace(1) %5727, i64 %5403, !dbg !324 + %8695 = getelementptr bfloat, ptr addrspace(1) %8691, i64 %4776, !dbg !325 + %8696 = getelementptr bfloat, ptr addrspace(1) %8692, i64 %4776, !dbg !325 + %8697 = getelementptr bfloat, ptr addrspace(1) %8693, i64 %4776, !dbg !325 + %8698 = getelementptr bfloat, ptr addrspace(1) %8694, i64 %4776, !dbg !325 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5112, ptr addrspace(1) %8687, i32 %5413) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5115, ptr addrspace(1) %8688, i32 %5414) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5118, ptr addrspace(1) %8689, i32 %5415) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5121, ptr addrspace(1) %8690, i32 %5416) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + %8699 = getelementptr float, ptr addrspace(1) %5728, i64 %5433, !dbg !327 + %8700 = getelementptr float, ptr addrspace(1) %5728, i64 %5434, !dbg !327 + %8701 = getelementptr float, ptr addrspace(1) %5728, i64 %5435, !dbg !327 + %8702 = getelementptr float, ptr addrspace(1) %5728, i64 %5436, !dbg !327 + %8703 = getelementptr float, ptr addrspace(1) %5728, i64 %5437, !dbg !327 + %8704 = getelementptr float, ptr addrspace(1) %5728, i64 %5438, !dbg !327 + %8705 = getelementptr float, ptr addrspace(1) %5728, i64 %5439, !dbg !327 + %8706 = getelementptr float, ptr addrspace(1) %5728, i64 %5440, !dbg !327 + %8707 = getelementptr float, ptr addrspace(1) %5728, i64 %5441, !dbg !327 + %8708 = getelementptr float, ptr addrspace(1) %5728, i64 %5442, !dbg !327 + %8709 = getelementptr float, ptr addrspace(1) %5728, i64 %5443, !dbg !327 + %8710 = getelementptr float, ptr addrspace(1) %5728, i64 %5444, !dbg !327 + %8711 = getelementptr float, ptr addrspace(1) %5728, i64 %5445, !dbg !327 + %8712 = getelementptr float, ptr addrspace(1) %5728, i64 %5446, !dbg !327 + %8713 = getelementptr float, ptr addrspace(1) %5728, i64 %5447, !dbg !327 + %8714 = getelementptr float, ptr addrspace(1) %5728, i64 %5448, !dbg !327 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5190, ptr addrspace(1) %8699, i32 %5465, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5193, ptr addrspace(1) %8700, i32 %5466, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5196, ptr addrspace(1) %8701, i32 %5467, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5199, ptr addrspace(1) %8702, i32 %5468, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5202, ptr addrspace(1) %8703, i32 %5469, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5205, ptr addrspace(1) %8704, i32 %5470, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5208, ptr addrspace(1) %8705, i32 %5471, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5211, ptr addrspace(1) %8706, i32 %5472, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5214, ptr addrspace(1) %8707, i32 %5473, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5217, ptr addrspace(1) %8708, i32 %5474, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5220, ptr addrspace(1) %8709, i32 %5475, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5223, ptr addrspace(1) %8710, i32 %5476, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5226, ptr addrspace(1) %8711, i32 %5477, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5229, ptr addrspace(1) %8712, i32 %5478, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5232, ptr addrspace(1) %8713, i32 %5479, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5235, ptr addrspace(1) %8714, i32 %5480, i1 %5188) #3, !dbg !328 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !328 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5237, ptr addrspace(1) %8695, i32 %5413) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5238, ptr addrspace(1) %8696, i32 %5414) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5239, ptr addrspace(1) %8697, i32 %5415) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5240, ptr addrspace(1) %8698, i32 %5416) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + %8715 = getelementptr float, ptr addrspace(1) %5729, i64 %5433, !dbg !330 + %8716 = getelementptr float, ptr addrspace(1) %5729, i64 %5434, !dbg !330 + %8717 = getelementptr float, ptr addrspace(1) %5729, i64 %5435, !dbg !330 + %8718 = getelementptr float, ptr addrspace(1) %5729, i64 %5436, !dbg !330 + %8719 = getelementptr float, ptr addrspace(1) %5729, i64 %5437, !dbg !330 + %8720 = getelementptr float, ptr addrspace(1) %5729, i64 %5438, !dbg !330 + %8721 = getelementptr float, ptr addrspace(1) %5729, i64 %5439, !dbg !330 + %8722 = getelementptr float, ptr addrspace(1) %5729, i64 %5440, !dbg !330 + %8723 = getelementptr float, ptr addrspace(1) %5729, i64 %5441, !dbg !330 + %8724 = getelementptr float, ptr addrspace(1) %5729, i64 %5442, !dbg !330 + %8725 = getelementptr float, ptr addrspace(1) %5729, i64 %5443, !dbg !330 + %8726 = getelementptr float, ptr addrspace(1) %5729, i64 %5444, !dbg !330 + %8727 = getelementptr float, ptr addrspace(1) %5729, i64 %5445, !dbg !330 + %8728 = getelementptr float, ptr addrspace(1) %5729, i64 %5446, !dbg !330 + %8729 = getelementptr float, ptr addrspace(1) %5729, i64 %5447, !dbg !330 + %8730 = getelementptr float, ptr addrspace(1) %5729, i64 %5448, !dbg !330 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5241, ptr addrspace(1) %8715, i32 %5465, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5242, ptr addrspace(1) %8716, i32 %5466, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5243, ptr addrspace(1) %8717, i32 %5467, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5244, ptr addrspace(1) %8718, i32 %5468, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5245, ptr addrspace(1) %8719, i32 %5469, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5246, ptr addrspace(1) %8720, i32 %5470, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5247, ptr addrspace(1) %8721, i32 %5471, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5248, ptr addrspace(1) %8722, i32 %5472, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5249, ptr addrspace(1) %8723, i32 %5473, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5250, ptr addrspace(1) %8724, i32 %5474, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5251, ptr addrspace(1) %8725, i32 %5475, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5252, ptr addrspace(1) %8726, i32 %5476, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5253, ptr addrspace(1) %8727, i32 %5477, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5254, ptr addrspace(1) %8728, i32 %5478, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5255, ptr addrspace(1) %8729, i32 %5479, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5256, ptr addrspace(1) %8730, i32 %5480, i1 %5188) #3, !dbg !331 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !331 + %8731 = getelementptr i8, ptr addrspace(1) %8687, i64 524288, !dbg !332 + %8732 = getelementptr i8, ptr addrspace(1) %8688, i64 524288, !dbg !332 + %8733 = getelementptr i8, ptr addrspace(1) %8689, i64 524288, !dbg !332 + %8734 = getelementptr i8, ptr addrspace(1) %8690, i64 524288, !dbg !332 + %8735 = getelementptr i8, ptr addrspace(1) %8695, i64 16384, !dbg !333 + %8736 = getelementptr i8, ptr addrspace(1) %8696, i64 16384, !dbg !333 + %8737 = getelementptr i8, ptr addrspace(1) %8697, i64 16384, !dbg !333 + %8738 = getelementptr i8, ptr addrspace(1) %8698, i64 16384, !dbg !333 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5286, ptr addrspace(1) %8731, i32 %5510) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5288, ptr addrspace(1) %8732, i32 %5511) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5290, ptr addrspace(1) %8733, i32 %5512) #3, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5292, ptr addrspace(1) %8734, i32 %5513) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + %8739 = getelementptr float, ptr addrspace(1) %5728, i64 %5530, !dbg !327 + %8740 = getelementptr float, ptr addrspace(1) %5728, i64 %5531, !dbg !327 + %8741 = getelementptr float, ptr addrspace(1) %5728, i64 %5532, !dbg !327 + %8742 = getelementptr float, ptr addrspace(1) %5728, i64 %5533, !dbg !327 + %8743 = getelementptr float, ptr addrspace(1) %5728, i64 %5534, !dbg !327 + %8744 = getelementptr float, ptr addrspace(1) %5728, i64 %5535, !dbg !327 + %8745 = getelementptr float, ptr addrspace(1) %5728, i64 %5536, !dbg !327 + %8746 = getelementptr float, ptr addrspace(1) %5728, i64 %5537, !dbg !327 + %8747 = getelementptr float, ptr addrspace(1) %5728, i64 %5538, !dbg !327 + %8748 = getelementptr float, ptr addrspace(1) %5728, i64 %5539, !dbg !327 + %8749 = getelementptr float, ptr addrspace(1) %5728, i64 %5540, !dbg !327 + %8750 = getelementptr float, ptr addrspace(1) %5728, i64 %5541, !dbg !327 + %8751 = getelementptr float, ptr addrspace(1) %5728, i64 %5542, !dbg !327 + %8752 = getelementptr float, ptr addrspace(1) %5728, i64 %5543, !dbg !327 + %8753 = getelementptr float, ptr addrspace(1) %5728, i64 %5544, !dbg !327 + %8754 = getelementptr float, ptr addrspace(1) %5728, i64 %5545, !dbg !327 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5342, ptr addrspace(1) %8739, i32 %5562, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5344, ptr addrspace(1) %8740, i32 %5563, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5346, ptr addrspace(1) %8741, i32 %5564, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5348, ptr addrspace(1) %8742, i32 %5565, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5350, ptr addrspace(1) %8743, i32 %5566, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5352, ptr addrspace(1) %8744, i32 %5567, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5354, ptr addrspace(1) %8745, i32 %5568, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5356, ptr addrspace(1) %8746, i32 %5569, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5358, ptr addrspace(1) %8747, i32 %5570, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5360, ptr addrspace(1) %8748, i32 %5571, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5362, ptr addrspace(1) %8749, i32 %5572, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5364, ptr addrspace(1) %8750, i32 %5573, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5366, ptr addrspace(1) %8751, i32 %5574, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5368, ptr addrspace(1) %8752, i32 %5575, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5370, ptr addrspace(1) %8753, i32 %5576, i1 %5188) #3, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5372, ptr addrspace(1) %8754, i32 %5577, i1 %5188) #3, !dbg !328 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !328 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5374, ptr addrspace(1) %8735, i32 %5510) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5375, ptr addrspace(1) %8736, i32 %5511) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5376, ptr addrspace(1) %8737, i32 %5512) #3, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5377, ptr addrspace(1) %8738, i32 %5513) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + %8755 = getelementptr float, ptr addrspace(1) %5729, i64 %5530, !dbg !330 + %8756 = getelementptr float, ptr addrspace(1) %5729, i64 %5531, !dbg !330 + %8757 = getelementptr float, ptr addrspace(1) %5729, i64 %5532, !dbg !330 + %8758 = getelementptr float, ptr addrspace(1) %5729, i64 %5533, !dbg !330 + %8759 = getelementptr float, ptr addrspace(1) %5729, i64 %5534, !dbg !330 + %8760 = getelementptr float, ptr addrspace(1) %5729, i64 %5535, !dbg !330 + %8761 = getelementptr float, ptr addrspace(1) %5729, i64 %5536, !dbg !330 + %8762 = getelementptr float, ptr addrspace(1) %5729, i64 %5537, !dbg !330 + %8763 = getelementptr float, ptr addrspace(1) %5729, i64 %5538, !dbg !330 + %8764 = getelementptr float, ptr addrspace(1) %5729, i64 %5539, !dbg !330 + %8765 = getelementptr float, ptr addrspace(1) %5729, i64 %5540, !dbg !330 + %8766 = getelementptr float, ptr addrspace(1) %5729, i64 %5541, !dbg !330 + %8767 = getelementptr float, ptr addrspace(1) %5729, i64 %5542, !dbg !330 + %8768 = getelementptr float, ptr addrspace(1) %5729, i64 %5543, !dbg !330 + %8769 = getelementptr float, ptr addrspace(1) %5729, i64 %5544, !dbg !330 + %8770 = getelementptr float, ptr addrspace(1) %5729, i64 %5545, !dbg !330 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5378, ptr addrspace(1) %8755, i32 %5562, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5379, ptr addrspace(1) %8756, i32 %5563, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5380, ptr addrspace(1) %8757, i32 %5564, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5381, ptr addrspace(1) %8758, i32 %5565, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5382, ptr addrspace(1) %8759, i32 %5566, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5383, ptr addrspace(1) %8760, i32 %5567, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5384, ptr addrspace(1) %8761, i32 %5568, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5385, ptr addrspace(1) %8762, i32 %5569, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5386, ptr addrspace(1) %8763, i32 %5570, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5387, ptr addrspace(1) %8764, i32 %5571, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5388, ptr addrspace(1) %8765, i32 %5572, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5389, ptr addrspace(1) %8766, i32 %5573, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5390, ptr addrspace(1) %8767, i32 %5574, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5391, ptr addrspace(1) %8768, i32 %5575, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5392, ptr addrspace(1) %8769, i32 %5576, i1 %5188) #3, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5393, ptr addrspace(1) %8770, i32 %5577, i1 %5188) #3, !dbg !331 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !331 + br i1 %5404, label %.lr.ph1873, label %._crit_edge1874, !dbg !334 + +.lr.ph1873: ; preds = %._crit_edge1701, %__nv_exp2f.exit1321 + %.pn605.pn1871 = phi i32 [ %.pn6051855, %__nv_exp2f.exit1321 ], [ %5078, %._crit_edge1701 ] + %.pn607.pn1870 = phi i32 [ %.pn6071854, %__nv_exp2f.exit1321 ], [ %5076, %._crit_edge1701 ] + %.pn609.pn1869 = phi i32 [ %.pn6091853, %__nv_exp2f.exit1321 ], [ %5074, %._crit_edge1701 ] + %.pn611.pn1868 = phi i32 [ %.pn6111852, %__nv_exp2f.exit1321 ], [ %5072, %._crit_edge1701 ] + %.pn613.pn1867 = phi i32 [ %.pn6131851, %__nv_exp2f.exit1321 ], [ %5070, %._crit_edge1701 ] + %.pn615.pn1866 = phi i32 [ %.pn6151850, %__nv_exp2f.exit1321 ], [ %5068, %._crit_edge1701 ] + %.pn617.pn1865 = phi i32 [ %.pn6171849, %__nv_exp2f.exit1321 ], [ %5066, %._crit_edge1701 ] + %.pn619.pn1864 = phi i32 [ %.pn6191848, %__nv_exp2f.exit1321 ], [ %5064, %._crit_edge1701 ] + %.pn621.pn1863 = phi i32 [ %.pn6211847, %__nv_exp2f.exit1321 ], [ %5062, %._crit_edge1701 ] + %.pn623.pn1862 = phi i32 [ %.pn6231846, %__nv_exp2f.exit1321 ], [ %5060, %._crit_edge1701 ] + %.pn625.pn1861 = phi i32 [ %.pn6251845, %__nv_exp2f.exit1321 ], [ %5058, %._crit_edge1701 ] + %.pn627.pn1860 = phi i32 [ %.pn6271844, %__nv_exp2f.exit1321 ], [ %5056, %._crit_edge1701 ] + %.pn629.pn1859 = phi i32 [ %.pn6291843, %__nv_exp2f.exit1321 ], [ %5054, %._crit_edge1701 ] + %.pn631.pn1858 = phi i32 [ %.pn6311842, %__nv_exp2f.exit1321 ], [ %5053, %._crit_edge1701 ] + %.pn633.pn1857 = phi i32 [ %.pn6331841, %__nv_exp2f.exit1321 ], [ %5052, %._crit_edge1701 ] + %.pn635.pn1856 = phi i32 [ %.pn6351840, %__nv_exp2f.exit1321 ], [ %5051, %._crit_edge1701 ] + %8771 = phi i32 [ %8788, %__nv_exp2f.exit1321 ], [ -1, %._crit_edge1701 ] + %8772 = phi i32 [ %10810, %__nv_exp2f.exit1321 ], [ 1, %._crit_edge1701 ] + %8773 = phi i32 [ %8791, %__nv_exp2f.exit1321 ], [ -1, %._crit_edge1701 ] + %8774 = phi i32 [ %10813, %__nv_exp2f.exit1321 ], [ 1, %._crit_edge1701 ] + %.pn6051855 = phi i32 [ %10799, %__nv_exp2f.exit1321 ], [ %5497, %._crit_edge1701 ] + %.pn6071854 = phi i32 [ %10798, %__nv_exp2f.exit1321 ], [ %5496, %._crit_edge1701 ] + %.pn6091853 = phi i32 [ %10797, %__nv_exp2f.exit1321 ], [ %5495, %._crit_edge1701 ] + %.pn6111852 = phi i32 [ %10796, %__nv_exp2f.exit1321 ], [ %5494, %._crit_edge1701 ] + %.pn6131851 = phi i32 [ %10795, %__nv_exp2f.exit1321 ], [ %5493, %._crit_edge1701 ] + %.pn6151850 = phi i32 [ %10794, %__nv_exp2f.exit1321 ], [ %5492, %._crit_edge1701 ] + %.pn6171849 = phi i32 [ %10793, %__nv_exp2f.exit1321 ], [ %5491, %._crit_edge1701 ] + %.pn6191848 = phi i32 [ %10792, %__nv_exp2f.exit1321 ], [ %5490, %._crit_edge1701 ] + %.pn6211847 = phi i32 [ %10791, %__nv_exp2f.exit1321 ], [ %5489, %._crit_edge1701 ] + %.pn6231846 = phi i32 [ %10790, %__nv_exp2f.exit1321 ], [ %5488, %._crit_edge1701 ] + %.pn6251845 = phi i32 [ %10789, %__nv_exp2f.exit1321 ], [ %5487, %._crit_edge1701 ] + %.pn6271844 = phi i32 [ %10788, %__nv_exp2f.exit1321 ], [ %5486, %._crit_edge1701 ] + %.pn6291843 = phi i32 [ %10787, %__nv_exp2f.exit1321 ], [ %5485, %._crit_edge1701 ] + %.pn6311842 = phi i32 [ %10786, %__nv_exp2f.exit1321 ], [ %5484, %._crit_edge1701 ] + %.pn6331841 = phi i32 [ %10785, %__nv_exp2f.exit1321 ], [ %5483, %._crit_edge1701 ] + %.pn6351840 = phi i32 [ %10784, %__nv_exp2f.exit1321 ], [ %5482, %._crit_edge1701 ] + %8775 = phi i32 [ %10804, %__nv_exp2f.exit1321 ], [ %5498, %._crit_edge1701 ] + %8776 = phi i32 [ %10805, %__nv_exp2f.exit1321 ], [ %5499, %._crit_edge1701 ] + %8777 = phi i32 [ %10806, %__nv_exp2f.exit1321 ], [ %5500, %._crit_edge1701 ] + %8778 = phi i32 [ %10807, %__nv_exp2f.exit1321 ], [ %5501, %._crit_edge1701 ] + %.pn5551839 = phi ptr addrspace(1) [ %10783, %__nv_exp2f.exit1321 ], [ %8738, %._crit_edge1701 ] + %.pn5711838 = phi ptr addrspace(1) [ %10782, %__nv_exp2f.exit1321 ], [ %8737, %._crit_edge1701 ] + %.pn5871837 = phi ptr addrspace(1) [ %10781, %__nv_exp2f.exit1321 ], [ %8736, %._crit_edge1701 ] + %.pn6031836 = phi ptr addrspace(1) [ %10780, %__nv_exp2f.exit1321 ], [ %8735, %._crit_edge1701 ] + %8779 = phi i32 [ %10800, %__nv_exp2f.exit1321 ], [ %5498, %._crit_edge1701 ] + %8780 = phi i32 [ %10801, %__nv_exp2f.exit1321 ], [ %5499, %._crit_edge1701 ] + %8781 = phi i32 [ %10802, %__nv_exp2f.exit1321 ], [ %5500, %._crit_edge1701 ] + %8782 = phi i32 [ %10803, %__nv_exp2f.exit1321 ], [ %5501, %._crit_edge1701 ] + %.pn4911835 = phi ptr addrspace(1) [ %10777, %__nv_exp2f.exit1321 ], [ %8734, %._crit_edge1701 ] + %.pn5071834 = phi ptr addrspace(1) [ %10776, %__nv_exp2f.exit1321 ], [ %8733, %._crit_edge1701 ] + %.pn5231833 = phi ptr addrspace(1) [ %10775, %__nv_exp2f.exit1321 ], [ %8732, %._crit_edge1701 ] + %.pn5391832 = phi ptr addrspace(1) [ %10774, %__nv_exp2f.exit1321 ], [ %8731, %._crit_edge1701 ] + %.pn3491831 = phi float [ %9895, %__nv_exp2f.exit1321 ], [ %8618, %._crit_edge1701 ] + %.pn3511830 = phi float [ %9894, %__nv_exp2f.exit1321 ], [ %8617, %._crit_edge1701 ] + %.pn3531829 = phi float [ %9893, %__nv_exp2f.exit1321 ], [ %8616, %._crit_edge1701 ] + %.pn3551828 = phi float [ %9892, %__nv_exp2f.exit1321 ], [ %8615, %._crit_edge1701 ] + %.pn3571827 = phi float [ %9891, %__nv_exp2f.exit1321 ], [ %8614, %._crit_edge1701 ] + %.pn3591826 = phi float [ %9890, %__nv_exp2f.exit1321 ], [ %8613, %._crit_edge1701 ] + %.pn3611825 = phi float [ %9889, %__nv_exp2f.exit1321 ], [ %8612, %._crit_edge1701 ] + %.pn3631824 = phi float [ %9888, %__nv_exp2f.exit1321 ], [ %8611, %._crit_edge1701 ] + %.pn3651823 = phi float [ %9887, %__nv_exp2f.exit1321 ], [ %8610, %._crit_edge1701 ] + %.pn3671822 = phi float [ %9886, %__nv_exp2f.exit1321 ], [ %8609, %._crit_edge1701 ] + %.pn3691821 = phi float [ %9885, %__nv_exp2f.exit1321 ], [ %8608, %._crit_edge1701 ] + %.pn3711820 = phi float [ %9884, %__nv_exp2f.exit1321 ], [ %8607, %._crit_edge1701 ] + %.pn3731819 = phi float [ %9883, %__nv_exp2f.exit1321 ], [ %8606, %._crit_edge1701 ] + %.pn3751818 = phi float [ %9882, %__nv_exp2f.exit1321 ], [ %8605, %._crit_edge1701 ] + %.pn3771817 = phi float [ %9881, %__nv_exp2f.exit1321 ], [ %8604, %._crit_edge1701 ] + %.pn3791816 = phi float [ %9880, %__nv_exp2f.exit1321 ], [ %8603, %._crit_edge1701 ] + %.pn3811815 = phi float [ %9879, %__nv_exp2f.exit1321 ], [ %8602, %._crit_edge1701 ] + %.pn3831814 = phi float [ %9878, %__nv_exp2f.exit1321 ], [ %8601, %._crit_edge1701 ] + %.pn3851813 = phi float [ %9877, %__nv_exp2f.exit1321 ], [ %8600, %._crit_edge1701 ] + %.pn3871812 = phi float [ %9876, %__nv_exp2f.exit1321 ], [ %8599, %._crit_edge1701 ] + %.pn3891811 = phi float [ %9875, %__nv_exp2f.exit1321 ], [ %8598, %._crit_edge1701 ] + %.pn3911810 = phi float [ %9874, %__nv_exp2f.exit1321 ], [ %8597, %._crit_edge1701 ] + %.pn3931809 = phi float [ %9873, %__nv_exp2f.exit1321 ], [ %8596, %._crit_edge1701 ] + %.pn3951808 = phi float [ %9872, %__nv_exp2f.exit1321 ], [ %8595, %._crit_edge1701 ] + %.pn3971807 = phi float [ %9871, %__nv_exp2f.exit1321 ], [ %8594, %._crit_edge1701 ] + %.pn3991806 = phi float [ %9870, %__nv_exp2f.exit1321 ], [ %8593, %._crit_edge1701 ] + %.pn4011805 = phi float [ %9869, %__nv_exp2f.exit1321 ], [ %8592, %._crit_edge1701 ] + %.pn4031804 = phi float [ %9868, %__nv_exp2f.exit1321 ], [ %8591, %._crit_edge1701 ] + %.pn4051803 = phi float [ %9867, %__nv_exp2f.exit1321 ], [ %8590, %._crit_edge1701 ] + %.pn4071802 = phi float [ %9866, %__nv_exp2f.exit1321 ], [ %8589, %._crit_edge1701 ] + %.pn4091801 = phi float [ %9865, %__nv_exp2f.exit1321 ], [ %8588, %._crit_edge1701 ] + %.pn4111800 = phi float [ %9864, %__nv_exp2f.exit1321 ], [ %8587, %._crit_edge1701 ] + %.pn4131799 = phi float [ %9863, %__nv_exp2f.exit1321 ], [ %8586, %._crit_edge1701 ] + %.pn4151798 = phi float [ %9862, %__nv_exp2f.exit1321 ], [ %8585, %._crit_edge1701 ] + %.pn4171797 = phi float [ %9861, %__nv_exp2f.exit1321 ], [ %8584, %._crit_edge1701 ] + %.pn4191796 = phi float [ %9860, %__nv_exp2f.exit1321 ], [ %8583, %._crit_edge1701 ] + %.pn4211795 = phi float [ %9859, %__nv_exp2f.exit1321 ], [ %8582, %._crit_edge1701 ] + %.pn4231794 = phi float [ %9858, %__nv_exp2f.exit1321 ], [ %8581, %._crit_edge1701 ] + %.pn4251793 = phi float [ %9857, %__nv_exp2f.exit1321 ], [ %8580, %._crit_edge1701 ] + %.pn4271792 = phi float [ %9856, %__nv_exp2f.exit1321 ], [ %8579, %._crit_edge1701 ] + %.pn4291791 = phi float [ %9855, %__nv_exp2f.exit1321 ], [ %8578, %._crit_edge1701 ] + %.pn4311790 = phi float [ %9854, %__nv_exp2f.exit1321 ], [ %8577, %._crit_edge1701 ] + %.pn4331789 = phi float [ %9853, %__nv_exp2f.exit1321 ], [ %8576, %._crit_edge1701 ] + %.pn4351788 = phi float [ %9852, %__nv_exp2f.exit1321 ], [ %8575, %._crit_edge1701 ] + %.pn4371787 = phi float [ %9851, %__nv_exp2f.exit1321 ], [ %8574, %._crit_edge1701 ] + %.pn4391786 = phi float [ %9850, %__nv_exp2f.exit1321 ], [ %8573, %._crit_edge1701 ] + %.pn4411785 = phi float [ %9849, %__nv_exp2f.exit1321 ], [ %8572, %._crit_edge1701 ] + %.pn4431784 = phi float [ %9848, %__nv_exp2f.exit1321 ], [ %8571, %._crit_edge1701 ] + %.pn4451783 = phi float [ %9847, %__nv_exp2f.exit1321 ], [ %8570, %._crit_edge1701 ] + %.pn4471782 = phi float [ %9846, %__nv_exp2f.exit1321 ], [ %8569, %._crit_edge1701 ] + %.pn4491781 = phi float [ %9845, %__nv_exp2f.exit1321 ], [ %8568, %._crit_edge1701 ] + %.pn4511780 = phi float [ %9844, %__nv_exp2f.exit1321 ], [ %8567, %._crit_edge1701 ] + %.pn4531779 = phi float [ %9843, %__nv_exp2f.exit1321 ], [ %8566, %._crit_edge1701 ] + %.pn4551778 = phi float [ %9842, %__nv_exp2f.exit1321 ], [ %8565, %._crit_edge1701 ] + %.pn4571777 = phi float [ %9841, %__nv_exp2f.exit1321 ], [ %8564, %._crit_edge1701 ] + %.pn4591776 = phi float [ %9840, %__nv_exp2f.exit1321 ], [ %8563, %._crit_edge1701 ] + %.pn4611775 = phi float [ %9839, %__nv_exp2f.exit1321 ], [ %8562, %._crit_edge1701 ] + %.pn4631774 = phi float [ %9838, %__nv_exp2f.exit1321 ], [ %8561, %._crit_edge1701 ] + %.pn4651773 = phi float [ %9837, %__nv_exp2f.exit1321 ], [ %8560, %._crit_edge1701 ] + %.pn4671772 = phi float [ %9836, %__nv_exp2f.exit1321 ], [ %8559, %._crit_edge1701 ] + %.pn4691771 = phi float [ %9835, %__nv_exp2f.exit1321 ], [ %8558, %._crit_edge1701 ] + %.pn4711770 = phi float [ %9834, %__nv_exp2f.exit1321 ], [ %8557, %._crit_edge1701 ] + %.pn4731769 = phi float [ %9833, %__nv_exp2f.exit1321 ], [ %8556, %._crit_edge1701 ] + %.pn4751768 = phi float [ %9832, %__nv_exp2f.exit1321 ], [ %8555, %._crit_edge1701 ] + %.pn2211767 = phi float [ %10751, %__nv_exp2f.exit1321 ], [ %8682, %._crit_edge1701 ] + %.pn2231766 = phi float [ %10750, %__nv_exp2f.exit1321 ], [ %8681, %._crit_edge1701 ] + %.pn2251765 = phi float [ %10749, %__nv_exp2f.exit1321 ], [ %8680, %._crit_edge1701 ] + %.pn2271764 = phi float [ %10748, %__nv_exp2f.exit1321 ], [ %8679, %._crit_edge1701 ] + %.pn2291763 = phi float [ %10747, %__nv_exp2f.exit1321 ], [ %8678, %._crit_edge1701 ] + %.pn2311762 = phi float [ %10746, %__nv_exp2f.exit1321 ], [ %8677, %._crit_edge1701 ] + %.pn2331761 = phi float [ %10745, %__nv_exp2f.exit1321 ], [ %8676, %._crit_edge1701 ] + %.pn2351760 = phi float [ %10744, %__nv_exp2f.exit1321 ], [ %8675, %._crit_edge1701 ] + %.pn2371759 = phi float [ %10743, %__nv_exp2f.exit1321 ], [ %8674, %._crit_edge1701 ] + %.pn2391758 = phi float [ %10742, %__nv_exp2f.exit1321 ], [ %8673, %._crit_edge1701 ] + %.pn2411757 = phi float [ %10741, %__nv_exp2f.exit1321 ], [ %8672, %._crit_edge1701 ] + %.pn2431756 = phi float [ %10740, %__nv_exp2f.exit1321 ], [ %8671, %._crit_edge1701 ] + %.pn2451755 = phi float [ %10739, %__nv_exp2f.exit1321 ], [ %8670, %._crit_edge1701 ] + %.pn2471754 = phi float [ %10738, %__nv_exp2f.exit1321 ], [ %8669, %._crit_edge1701 ] + %.pn2491753 = phi float [ %10737, %__nv_exp2f.exit1321 ], [ %8668, %._crit_edge1701 ] + %.pn2511752 = phi float [ %10736, %__nv_exp2f.exit1321 ], [ %8667, %._crit_edge1701 ] + %.pn2531751 = phi float [ %10735, %__nv_exp2f.exit1321 ], [ %8666, %._crit_edge1701 ] + %.pn2551750 = phi float [ %10734, %__nv_exp2f.exit1321 ], [ %8665, %._crit_edge1701 ] + %.pn2571749 = phi float [ %10733, %__nv_exp2f.exit1321 ], [ %8664, %._crit_edge1701 ] + %.pn2591748 = phi float [ %10732, %__nv_exp2f.exit1321 ], [ %8663, %._crit_edge1701 ] + %.pn2611747 = phi float [ %10731, %__nv_exp2f.exit1321 ], [ %8662, %._crit_edge1701 ] + %.pn2631746 = phi float [ %10730, %__nv_exp2f.exit1321 ], [ %8661, %._crit_edge1701 ] + %.pn2651745 = phi float [ %10729, %__nv_exp2f.exit1321 ], [ %8660, %._crit_edge1701 ] + %.pn2671744 = phi float [ %10728, %__nv_exp2f.exit1321 ], [ %8659, %._crit_edge1701 ] + %.pn2691743 = phi float [ %10727, %__nv_exp2f.exit1321 ], [ %8658, %._crit_edge1701 ] + %.pn2711742 = phi float [ %10726, %__nv_exp2f.exit1321 ], [ %8657, %._crit_edge1701 ] + %.pn2731741 = phi float [ %10725, %__nv_exp2f.exit1321 ], [ %8656, %._crit_edge1701 ] + %.pn2751740 = phi float [ %10724, %__nv_exp2f.exit1321 ], [ %8655, %._crit_edge1701 ] + %.pn2771739 = phi float [ %10723, %__nv_exp2f.exit1321 ], [ %8654, %._crit_edge1701 ] + %.pn2791738 = phi float [ %10722, %__nv_exp2f.exit1321 ], [ %8653, %._crit_edge1701 ] + %.pn2811737 = phi float [ %10721, %__nv_exp2f.exit1321 ], [ %8652, %._crit_edge1701 ] + %.pn2831736 = phi float [ %10720, %__nv_exp2f.exit1321 ], [ %8651, %._crit_edge1701 ] + %.pn2851735 = phi float [ %10719, %__nv_exp2f.exit1321 ], [ %8650, %._crit_edge1701 ] + %.pn2871734 = phi float [ %10718, %__nv_exp2f.exit1321 ], [ %8649, %._crit_edge1701 ] + %.pn2891733 = phi float [ %10717, %__nv_exp2f.exit1321 ], [ %8648, %._crit_edge1701 ] + %.pn2911732 = phi float [ %10716, %__nv_exp2f.exit1321 ], [ %8647, %._crit_edge1701 ] + %.pn2931731 = phi float [ %10715, %__nv_exp2f.exit1321 ], [ %8646, %._crit_edge1701 ] + %.pn2951730 = phi float [ %10714, %__nv_exp2f.exit1321 ], [ %8645, %._crit_edge1701 ] + %.pn2971729 = phi float [ %10713, %__nv_exp2f.exit1321 ], [ %8644, %._crit_edge1701 ] + %.pn2991728 = phi float [ %10712, %__nv_exp2f.exit1321 ], [ %8643, %._crit_edge1701 ] + %.pn3011727 = phi float [ %10711, %__nv_exp2f.exit1321 ], [ %8642, %._crit_edge1701 ] + %.pn3031726 = phi float [ %10710, %__nv_exp2f.exit1321 ], [ %8641, %._crit_edge1701 ] + %.pn3051725 = phi float [ %10709, %__nv_exp2f.exit1321 ], [ %8640, %._crit_edge1701 ] + %.pn3071724 = phi float [ %10708, %__nv_exp2f.exit1321 ], [ %8639, %._crit_edge1701 ] + %.pn3091723 = phi float [ %10707, %__nv_exp2f.exit1321 ], [ %8638, %._crit_edge1701 ] + %.pn3111722 = phi float [ %10706, %__nv_exp2f.exit1321 ], [ %8637, %._crit_edge1701 ] + %.pn3131721 = phi float [ %10705, %__nv_exp2f.exit1321 ], [ %8636, %._crit_edge1701 ] + %.pn3151720 = phi float [ %10704, %__nv_exp2f.exit1321 ], [ %8635, %._crit_edge1701 ] + %.pn3171719 = phi float [ %10703, %__nv_exp2f.exit1321 ], [ %8634, %._crit_edge1701 ] + %.pn3191718 = phi float [ %10702, %__nv_exp2f.exit1321 ], [ %8633, %._crit_edge1701 ] + %.pn3211717 = phi float [ %10701, %__nv_exp2f.exit1321 ], [ %8632, %._crit_edge1701 ] + %.pn3231716 = phi float [ %10700, %__nv_exp2f.exit1321 ], [ %8631, %._crit_edge1701 ] + %.pn3251715 = phi float [ %10699, %__nv_exp2f.exit1321 ], [ %8630, %._crit_edge1701 ] + %.pn3271714 = phi float [ %10698, %__nv_exp2f.exit1321 ], [ %8629, %._crit_edge1701 ] + %.pn3291713 = phi float [ %10697, %__nv_exp2f.exit1321 ], [ %8628, %._crit_edge1701 ] + %.pn3311712 = phi float [ %10696, %__nv_exp2f.exit1321 ], [ %8627, %._crit_edge1701 ] + %.pn3331711 = phi float [ %10695, %__nv_exp2f.exit1321 ], [ %8626, %._crit_edge1701 ] + %.pn3351710 = phi float [ %10694, %__nv_exp2f.exit1321 ], [ %8625, %._crit_edge1701 ] + %.pn3371709 = phi float [ %10693, %__nv_exp2f.exit1321 ], [ %8624, %._crit_edge1701 ] + %.pn3391708 = phi float [ %10692, %__nv_exp2f.exit1321 ], [ %8623, %._crit_edge1701 ] + %.pn3411707 = phi float [ %10691, %__nv_exp2f.exit1321 ], [ %8622, %._crit_edge1701 ] + %.pn3431706 = phi float [ %10690, %__nv_exp2f.exit1321 ], [ %8621, %._crit_edge1701 ] + %.pn3451705 = phi float [ %10689, %__nv_exp2f.exit1321 ], [ %8620, %._crit_edge1701 ] + %.pn3471704 = phi float [ %10688, %__nv_exp2f.exit1321 ], [ %8619, %._crit_edge1701 ] + %8783 = phi i32 [ %10752, %__nv_exp2f.exit1321 ], [ 0, %._crit_edge1701 ] + %8784 = icmp slt i32 %8783, %5578, !dbg !334 + %8785 = icmp slt i32 %8783, %5579, !dbg !334 + %8786 = add i32 %8771, 1, !dbg !334 + %8787 = icmp sgt i32 %8786, 1, !dbg !334 + %8788 = select i1 %8787, i32 0, i32 %8786, !dbg !334 + %8789 = add i32 %8773, 1, !dbg !334 + %8790 = icmp sgt i32 %8789, 2, !dbg !334 + %8791 = select i1 %8790, i32 0, i32 %8789, !dbg !334 + %8792 = icmp slt i32 %.pn635.pn1856, %17, !dbg !335 + %8793 = icmp slt i32 %.pn633.pn1857, %17, !dbg !335 + %8794 = icmp slt i32 %.pn631.pn1858, %17, !dbg !335 + %8795 = icmp slt i32 %.pn629.pn1859, %17, !dbg !335 + %8796 = icmp slt i32 %.pn627.pn1860, %17, !dbg !335 + %8797 = icmp slt i32 %.pn625.pn1861, %17, !dbg !335 + %8798 = icmp slt i32 %.pn623.pn1862, %17, !dbg !335 + %8799 = icmp slt i32 %.pn621.pn1863, %17, !dbg !335 + %8800 = icmp slt i32 %.pn619.pn1864, %17, !dbg !335 + %8801 = icmp slt i32 %.pn617.pn1865, %17, !dbg !335 + %8802 = icmp slt i32 %.pn615.pn1866, %17, !dbg !335 + %8803 = icmp slt i32 %.pn613.pn1867, %17, !dbg !335 + %8804 = icmp slt i32 %.pn611.pn1868, %17, !dbg !335 + %8805 = icmp slt i32 %.pn609.pn1869, %17, !dbg !335 + %8806 = icmp slt i32 %.pn607.pn1870, %17, !dbg !335 + %8807 = icmp slt i32 %.pn605.pn1871, %17, !dbg !335 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !326 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !326 + %8808 = shl i32 %8791, 13, !dbg !326 + %8809 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %8808, !dbg !326 + %8810 = shl i32 %8788, 6, !dbg !328 + %8811 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %8810, !dbg !328 + %8812 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5189, !dbg !328 + %8813 = load float, ptr addrspace(3) %8812, align 8, !dbg !328 + %8814 = getelementptr inbounds nuw i8, ptr addrspace(3) %8812, i32 4, !dbg !328 + %8815 = load float, ptr addrspace(3) %8814, align 4, !dbg !328 + %8816 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5195, !dbg !328 + %8817 = load float, ptr addrspace(3) %8816, align 8, !dbg !328 + %8818 = getelementptr inbounds nuw i8, ptr addrspace(3) %8816, i32 4, !dbg !328 + %8819 = load float, ptr addrspace(3) %8818, align 4, !dbg !328 + %8820 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5201, !dbg !328 + %8821 = load float, ptr addrspace(3) %8820, align 8, !dbg !328 + %8822 = getelementptr inbounds nuw i8, ptr addrspace(3) %8820, i32 4, !dbg !328 + %8823 = load float, ptr addrspace(3) %8822, align 4, !dbg !328 + %8824 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5207, !dbg !328 + %8825 = load float, ptr addrspace(3) %8824, align 8, !dbg !328 + %8826 = getelementptr inbounds nuw i8, ptr addrspace(3) %8824, i32 4, !dbg !328 + %8827 = load float, ptr addrspace(3) %8826, align 4, !dbg !328 + %8828 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5213, !dbg !328 + %8829 = load float, ptr addrspace(3) %8828, align 8, !dbg !328 + %8830 = getelementptr inbounds nuw i8, ptr addrspace(3) %8828, i32 4, !dbg !328 + %8831 = load float, ptr addrspace(3) %8830, align 4, !dbg !328 + %8832 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5219, !dbg !328 + %8833 = load float, ptr addrspace(3) %8832, align 8, !dbg !328 + %8834 = getelementptr inbounds nuw i8, ptr addrspace(3) %8832, i32 4, !dbg !328 + %8835 = load float, ptr addrspace(3) %8834, align 4, !dbg !328 + %8836 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5225, !dbg !328 + %8837 = load float, ptr addrspace(3) %8836, align 8, !dbg !328 + %8838 = getelementptr inbounds nuw i8, ptr addrspace(3) %8836, i32 4, !dbg !328 + %8839 = load float, ptr addrspace(3) %8838, align 4, !dbg !328 + %8840 = getelementptr inbounds nuw i8, ptr addrspace(3) %8811, i32 %5231, !dbg !328 + %8841 = load float, ptr addrspace(3) %8840, align 8, !dbg !328 + %8842 = getelementptr inbounds nuw i8, ptr addrspace(3) %8840, i32 4, !dbg !328 + %8843 = load float, ptr addrspace(3) %8842, align 4, !dbg !328 + %8844 = fcmp oeq float %8813, 0xFFF0000000000000, !dbg !336 + %8845 = fcmp oeq float %8815, 0xFFF0000000000000, !dbg !336 + %8846 = fcmp oeq float %8817, 0xFFF0000000000000, !dbg !336 + %8847 = fcmp oeq float %8819, 0xFFF0000000000000, !dbg !336 + %8848 = fcmp oeq float %8821, 0xFFF0000000000000, !dbg !336 + %8849 = fcmp oeq float %8823, 0xFFF0000000000000, !dbg !336 + %8850 = fcmp oeq float %8825, 0xFFF0000000000000, !dbg !336 + %8851 = fcmp oeq float %8827, 0xFFF0000000000000, !dbg !336 + %8852 = fcmp oeq float %8829, 0xFFF0000000000000, !dbg !336 + %8853 = fcmp oeq float %8831, 0xFFF0000000000000, !dbg !336 + %8854 = fcmp oeq float %8833, 0xFFF0000000000000, !dbg !336 + %8855 = fcmp oeq float %8835, 0xFFF0000000000000, !dbg !336 + %8856 = fcmp oeq float %8837, 0xFFF0000000000000, !dbg !336 + %8857 = fcmp oeq float %8839, 0xFFF0000000000000, !dbg !336 + %8858 = fcmp oeq float %8841, 0xFFF0000000000000, !dbg !336 + %8859 = fcmp oeq float %8843, 0xFFF0000000000000, !dbg !336 + %8860 = select i1 %8844, float 0.000000e+00, float %8813, !dbg !337 + %8861 = select i1 %8845, float 0.000000e+00, float %8815, !dbg !337 + %8862 = select i1 %8846, float 0.000000e+00, float %8817, !dbg !337 + %8863 = select i1 %8847, float 0.000000e+00, float %8819, !dbg !337 + %8864 = select i1 %8848, float 0.000000e+00, float %8821, !dbg !337 + %8865 = select i1 %8849, float 0.000000e+00, float %8823, !dbg !337 + %8866 = select i1 %8850, float 0.000000e+00, float %8825, !dbg !337 + %8867 = select i1 %8851, float 0.000000e+00, float %8827, !dbg !337 + %8868 = select i1 %8852, float 0.000000e+00, float %8829, !dbg !337 + %8869 = select i1 %8853, float 0.000000e+00, float %8831, !dbg !337 + %8870 = select i1 %8854, float 0.000000e+00, float %8833, !dbg !337 + %8871 = select i1 %8855, float 0.000000e+00, float %8835, !dbg !337 + %8872 = select i1 %8856, float 0.000000e+00, float %8837, !dbg !337 + %8873 = select i1 %8857, float 0.000000e+00, float %8839, !dbg !337 + %8874 = select i1 %8858, float 0.000000e+00, float %8841, !dbg !337 + %8875 = select i1 %8859, float 0.000000e+00, float %8843, !dbg !337 + %8876 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %51, i32 0, i32 31), !dbg !338 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !338 + %8877 = shl i32 %8876, 11, !dbg !338 + %8878 = and i32 %8877, 8192, !dbg !338 + %8879 = add i32 %8878, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %8880 = lshr exact i32 %8879, 4, !dbg !338 + %8881 = and i32 %8880, 16383, !dbg !338 + %8882 = zext nneg i32 %8881 to i64, !dbg !338 + %8883 = or disjoint i64 %8882, 4611686293372403712, !dbg !338 + %8884 = ptrtoint ptr addrspace(3) %8809 to i32, !dbg !338 + %8885 = lshr exact i32 %8884, 4, !dbg !338 + %8886 = and i32 %8885, 16383, !dbg !338 + %8887 = zext nneg i32 %8886 to i64, !dbg !338 + %8888 = or disjoint i64 %8887, 4611686293338849280, !dbg !338 + %8889 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %8883, i64 %8888) #3, !dbg !338 + %8890 = or disjoint i32 %8878, 32, !dbg !338 + %8891 = add i32 %8890, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %8892 = lshr exact i32 %8891, 4, !dbg !338 + %8893 = and i32 %8892, 16383, !dbg !338 + %8894 = zext nneg i32 %8893 to i64, !dbg !338 + %8895 = or disjoint i64 %8894, 4611686293372403712, !dbg !338 + %8896 = add i32 %8884, 32, !dbg !338 + %8897 = lshr exact i32 %8896, 4, !dbg !338 + %8898 = and i32 %8897, 16383, !dbg !338 + %8899 = zext nneg i32 %8898 to i64, !dbg !338 + %8900 = or disjoint i64 %8899, 4611686293338849280, !dbg !338 + %8901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 0, !dbg !338 + %8902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 1, !dbg !338 + %8903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 2, !dbg !338 + %8904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 3, !dbg !338 + %8905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 4, !dbg !338 + %8906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 5, !dbg !338 + %8907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 6, !dbg !338 + %8908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 7, !dbg !338 + %8909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 8, !dbg !338 + %8910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 9, !dbg !338 + %8911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 10, !dbg !338 + %8912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 11, !dbg !338 + %8913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 12, !dbg !338 + %8914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 13, !dbg !338 + %8915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 14, !dbg !338 + %8916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 15, !dbg !338 + %8917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 16, !dbg !338 + %8918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 17, !dbg !338 + %8919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 18, !dbg !338 + %8920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 19, !dbg !338 + %8921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 20, !dbg !338 + %8922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 21, !dbg !338 + %8923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 22, !dbg !338 + %8924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 23, !dbg !338 + %8925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 24, !dbg !338 + %8926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 25, !dbg !338 + %8927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 26, !dbg !338 + %8928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 27, !dbg !338 + %8929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 28, !dbg !338 + %8930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 29, !dbg !338 + %8931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 30, !dbg !338 + %8932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8889, 31, !dbg !338 + %8933 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8901, float %8902, float %8903, float %8904, float %8905, float %8906, float %8907, float %8908, float %8909, float %8910, float %8911, float %8912, float %8913, float %8914, float %8915, float %8916, float %8917, float %8918, float %8919, float %8920, float %8921, float %8922, float %8923, float %8924, float %8925, float %8926, float %8927, float %8928, float %8929, float %8930, float %8931, float %8932, i64 %8895, i64 %8900, i1 true) #3, !dbg !338 + %8934 = or disjoint i32 %8878, 64, !dbg !338 + %8935 = add i32 %8934, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %8936 = lshr exact i32 %8935, 4, !dbg !338 + %8937 = and i32 %8936, 16383, !dbg !338 + %8938 = zext nneg i32 %8937 to i64, !dbg !338 + %8939 = or disjoint i64 %8938, 4611686293372403712, !dbg !338 + %8940 = add i32 %8884, 64, !dbg !338 + %8941 = lshr exact i32 %8940, 4, !dbg !338 + %8942 = and i32 %8941, 16383, !dbg !338 + %8943 = zext nneg i32 %8942 to i64, !dbg !338 + %8944 = or disjoint i64 %8943, 4611686293338849280, !dbg !338 + %8945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 0, !dbg !338 + %8946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 1, !dbg !338 + %8947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 2, !dbg !338 + %8948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 3, !dbg !338 + %8949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 4, !dbg !338 + %8950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 5, !dbg !338 + %8951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 6, !dbg !338 + %8952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 7, !dbg !338 + %8953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 8, !dbg !338 + %8954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 9, !dbg !338 + %8955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 10, !dbg !338 + %8956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 11, !dbg !338 + %8957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 12, !dbg !338 + %8958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 13, !dbg !338 + %8959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 14, !dbg !338 + %8960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 15, !dbg !338 + %8961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 16, !dbg !338 + %8962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 17, !dbg !338 + %8963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 18, !dbg !338 + %8964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 19, !dbg !338 + %8965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 20, !dbg !338 + %8966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 21, !dbg !338 + %8967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 22, !dbg !338 + %8968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 23, !dbg !338 + %8969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 24, !dbg !338 + %8970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 25, !dbg !338 + %8971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 26, !dbg !338 + %8972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 27, !dbg !338 + %8973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 28, !dbg !338 + %8974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 29, !dbg !338 + %8975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 30, !dbg !338 + %8976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8933, 31, !dbg !338 + %8977 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8945, float %8946, float %8947, float %8948, float %8949, float %8950, float %8951, float %8952, float %8953, float %8954, float %8955, float %8956, float %8957, float %8958, float %8959, float %8960, float %8961, float %8962, float %8963, float %8964, float %8965, float %8966, float %8967, float %8968, float %8969, float %8970, float %8971, float %8972, float %8973, float %8974, float %8975, float %8976, i64 %8939, i64 %8944, i1 true) #3, !dbg !338 + %8978 = or disjoint i32 %8878, 96, !dbg !338 + %8979 = add i32 %8978, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %8980 = lshr exact i32 %8979, 4, !dbg !338 + %8981 = and i32 %8980, 16383, !dbg !338 + %8982 = zext nneg i32 %8981 to i64, !dbg !338 + %8983 = or disjoint i64 %8982, 4611686293372403712, !dbg !338 + %8984 = add i32 %8884, 96, !dbg !338 + %8985 = lshr exact i32 %8984, 4, !dbg !338 + %8986 = and i32 %8985, 16383, !dbg !338 + %8987 = zext nneg i32 %8986 to i64, !dbg !338 + %8988 = or disjoint i64 %8987, 4611686293338849280, !dbg !338 + %8989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 0, !dbg !338 + %8990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 1, !dbg !338 + %8991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 2, !dbg !338 + %8992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 3, !dbg !338 + %8993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 4, !dbg !338 + %8994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 5, !dbg !338 + %8995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 6, !dbg !338 + %8996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 7, !dbg !338 + %8997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 8, !dbg !338 + %8998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 9, !dbg !338 + %8999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 10, !dbg !338 + %9000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 11, !dbg !338 + %9001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 12, !dbg !338 + %9002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 13, !dbg !338 + %9003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 14, !dbg !338 + %9004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 15, !dbg !338 + %9005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 16, !dbg !338 + %9006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 17, !dbg !338 + %9007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 18, !dbg !338 + %9008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 19, !dbg !338 + %9009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 20, !dbg !338 + %9010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 21, !dbg !338 + %9011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 22, !dbg !338 + %9012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 23, !dbg !338 + %9013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 24, !dbg !338 + %9014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 25, !dbg !338 + %9015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 26, !dbg !338 + %9016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 27, !dbg !338 + %9017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 28, !dbg !338 + %9018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 29, !dbg !338 + %9019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 30, !dbg !338 + %9020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8977, 31, !dbg !338 + %9021 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8989, float %8990, float %8991, float %8992, float %8993, float %8994, float %8995, float %8996, float %8997, float %8998, float %8999, float %9000, float %9001, float %9002, float %9003, float %9004, float %9005, float %9006, float %9007, float %9008, float %9009, float %9010, float %9011, float %9012, float %9013, float %9014, float %9015, float %9016, float %9017, float %9018, float %9019, float %9020, i64 %8983, i64 %8988, i1 true) #3, !dbg !338 + %9022 = or disjoint i32 %8878, 16384, !dbg !338 + %9023 = add i32 %9022, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %9024 = lshr exact i32 %9023, 4, !dbg !338 + %9025 = and i32 %9024, 16383, !dbg !338 + %9026 = zext nneg i32 %9025 to i64, !dbg !338 + %9027 = or disjoint i64 %9026, 4611686293372403712, !dbg !338 + %9028 = add i32 %8884, 8192, !dbg !338 + %9029 = lshr exact i32 %9028, 4, !dbg !338 + %9030 = and i32 %9029, 16383, !dbg !338 + %9031 = zext nneg i32 %9030 to i64, !dbg !338 + %9032 = or disjoint i64 %9031, 4611686293338849280, !dbg !338 + %9033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 0, !dbg !338 + %9034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 1, !dbg !338 + %9035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 2, !dbg !338 + %9036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 3, !dbg !338 + %9037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 4, !dbg !338 + %9038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 5, !dbg !338 + %9039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 6, !dbg !338 + %9040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 7, !dbg !338 + %9041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 8, !dbg !338 + %9042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 9, !dbg !338 + %9043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 10, !dbg !338 + %9044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 11, !dbg !338 + %9045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 12, !dbg !338 + %9046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 13, !dbg !338 + %9047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 14, !dbg !338 + %9048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 15, !dbg !338 + %9049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 16, !dbg !338 + %9050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 17, !dbg !338 + %9051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 18, !dbg !338 + %9052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 19, !dbg !338 + %9053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 20, !dbg !338 + %9054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 21, !dbg !338 + %9055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 22, !dbg !338 + %9056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 23, !dbg !338 + %9057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 24, !dbg !338 + %9058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 25, !dbg !338 + %9059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 26, !dbg !338 + %9060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 27, !dbg !338 + %9061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 28, !dbg !338 + %9062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 29, !dbg !338 + %9063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 30, !dbg !338 + %9064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9021, 31, !dbg !338 + %9065 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9033, float %9034, float %9035, float %9036, float %9037, float %9038, float %9039, float %9040, float %9041, float %9042, float %9043, float %9044, float %9045, float %9046, float %9047, float %9048, float %9049, float %9050, float %9051, float %9052, float %9053, float %9054, float %9055, float %9056, float %9057, float %9058, float %9059, float %9060, float %9061, float %9062, float %9063, float %9064, i64 %9027, i64 %9032, i1 true) #3, !dbg !338 + %9066 = or disjoint i32 %8878, 16416, !dbg !338 + %9067 = add i32 %9066, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %9068 = lshr exact i32 %9067, 4, !dbg !338 + %9069 = and i32 %9068, 16383, !dbg !338 + %9070 = zext nneg i32 %9069 to i64, !dbg !338 + %9071 = or disjoint i64 %9070, 4611686293372403712, !dbg !338 + %9072 = add i32 %8884, 8224, !dbg !338 + %9073 = lshr exact i32 %9072, 4, !dbg !338 + %9074 = and i32 %9073, 16383, !dbg !338 + %9075 = zext nneg i32 %9074 to i64, !dbg !338 + %9076 = or disjoint i64 %9075, 4611686293338849280, !dbg !338 + %9077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 0, !dbg !338 + %9078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 1, !dbg !338 + %9079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 2, !dbg !338 + %9080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 3, !dbg !338 + %9081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 4, !dbg !338 + %9082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 5, !dbg !338 + %9083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 6, !dbg !338 + %9084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 7, !dbg !338 + %9085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 8, !dbg !338 + %9086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 9, !dbg !338 + %9087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 10, !dbg !338 + %9088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 11, !dbg !338 + %9089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 12, !dbg !338 + %9090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 13, !dbg !338 + %9091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 14, !dbg !338 + %9092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 15, !dbg !338 + %9093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 16, !dbg !338 + %9094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 17, !dbg !338 + %9095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 18, !dbg !338 + %9096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 19, !dbg !338 + %9097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 20, !dbg !338 + %9098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 21, !dbg !338 + %9099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 22, !dbg !338 + %9100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 23, !dbg !338 + %9101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 24, !dbg !338 + %9102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 25, !dbg !338 + %9103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 26, !dbg !338 + %9104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 27, !dbg !338 + %9105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 28, !dbg !338 + %9106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 29, !dbg !338 + %9107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 30, !dbg !338 + %9108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9065, 31, !dbg !338 + %9109 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9077, float %9078, float %9079, float %9080, float %9081, float %9082, float %9083, float %9084, float %9085, float %9086, float %9087, float %9088, float %9089, float %9090, float %9091, float %9092, float %9093, float %9094, float %9095, float %9096, float %9097, float %9098, float %9099, float %9100, float %9101, float %9102, float %9103, float %9104, float %9105, float %9106, float %9107, float %9108, i64 %9071, i64 %9076, i1 true) #3, !dbg !338 + %9110 = or disjoint i32 %8878, 16448, !dbg !338 + %9111 = add i32 %9110, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %9112 = lshr exact i32 %9111, 4, !dbg !338 + %9113 = and i32 %9112, 16383, !dbg !338 + %9114 = zext nneg i32 %9113 to i64, !dbg !338 + %9115 = or disjoint i64 %9114, 4611686293372403712, !dbg !338 + %9116 = add i32 %8884, 8256, !dbg !338 + %9117 = lshr exact i32 %9116, 4, !dbg !338 + %9118 = and i32 %9117, 16383, !dbg !338 + %9119 = zext nneg i32 %9118 to i64, !dbg !338 + %9120 = or disjoint i64 %9119, 4611686293338849280, !dbg !338 + %9121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 0, !dbg !338 + %9122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 1, !dbg !338 + %9123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 2, !dbg !338 + %9124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 3, !dbg !338 + %9125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 4, !dbg !338 + %9126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 5, !dbg !338 + %9127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 6, !dbg !338 + %9128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 7, !dbg !338 + %9129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 8, !dbg !338 + %9130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 9, !dbg !338 + %9131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 10, !dbg !338 + %9132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 11, !dbg !338 + %9133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 12, !dbg !338 + %9134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 13, !dbg !338 + %9135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 14, !dbg !338 + %9136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 15, !dbg !338 + %9137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 16, !dbg !338 + %9138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 17, !dbg !338 + %9139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 18, !dbg !338 + %9140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 19, !dbg !338 + %9141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 20, !dbg !338 + %9142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 21, !dbg !338 + %9143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 22, !dbg !338 + %9144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 23, !dbg !338 + %9145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 24, !dbg !338 + %9146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 25, !dbg !338 + %9147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 26, !dbg !338 + %9148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 27, !dbg !338 + %9149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 28, !dbg !338 + %9150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 29, !dbg !338 + %9151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 30, !dbg !338 + %9152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9109, 31, !dbg !338 + %9153 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9121, float %9122, float %9123, float %9124, float %9125, float %9126, float %9127, float %9128, float %9129, float %9130, float %9131, float %9132, float %9133, float %9134, float %9135, float %9136, float %9137, float %9138, float %9139, float %9140, float %9141, float %9142, float %9143, float %9144, float %9145, float %9146, float %9147, float %9148, float %9149, float %9150, float %9151, float %9152, i64 %9115, i64 %9120, i1 true) #3, !dbg !338 + %9154 = or disjoint i32 %8878, 16480, !dbg !338 + %9155 = add i32 %9154, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !338 + %9156 = lshr exact i32 %9155, 4, !dbg !338 + %9157 = and i32 %9156, 16383, !dbg !338 + %9158 = zext nneg i32 %9157 to i64, !dbg !338 + %9159 = or disjoint i64 %9158, 4611686293372403712, !dbg !338 + %9160 = add i32 %8884, 8288, !dbg !338 + %9161 = lshr exact i32 %9160, 4, !dbg !338 + %9162 = and i32 %9161, 16383, !dbg !338 + %9163 = zext nneg i32 %9162 to i64, !dbg !338 + %9164 = or disjoint i64 %9163, 4611686293338849280, !dbg !338 + %9165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 0, !dbg !338 + %9166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 1, !dbg !338 + %9167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 2, !dbg !338 + %9168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 3, !dbg !338 + %9169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 4, !dbg !338 + %9170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 5, !dbg !338 + %9171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 6, !dbg !338 + %9172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 7, !dbg !338 + %9173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 8, !dbg !338 + %9174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 9, !dbg !338 + %9175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 10, !dbg !338 + %9176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 11, !dbg !338 + %9177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 12, !dbg !338 + %9178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 13, !dbg !338 + %9179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 14, !dbg !338 + %9180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 15, !dbg !338 + %9181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 16, !dbg !338 + %9182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 17, !dbg !338 + %9183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 18, !dbg !338 + %9184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 19, !dbg !338 + %9185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 20, !dbg !338 + %9186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 21, !dbg !338 + %9187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 22, !dbg !338 + %9188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 23, !dbg !338 + %9189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 24, !dbg !338 + %9190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 25, !dbg !338 + %9191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 26, !dbg !338 + %9192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 27, !dbg !338 + %9193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 28, !dbg !338 + %9194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 29, !dbg !338 + %9195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 30, !dbg !338 + %9196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9153, 31, !dbg !338 + %9197 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9165, float %9166, float %9167, float %9168, float %9169, float %9170, float %9171, float %9172, float %9173, float %9174, float %9175, float %9176, float %9177, float %9178, float %9179, float %9180, float %9181, float %9182, float %9183, float %9184, float %9185, float %9186, float %9187, float %9188, float %9189, float %9190, float %9191, float %9192, float %9193, float %9194, float %9195, float %9196, i64 %9159, i64 %9164, i1 true) #3, !dbg !338 + %9198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 0, !dbg !338 + %9199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 1, !dbg !338 + %9200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 2, !dbg !338 + %9201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 3, !dbg !338 + %9202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 4, !dbg !338 + %9203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 5, !dbg !338 + %9204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 6, !dbg !338 + %9205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 7, !dbg !338 + %9206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 8, !dbg !338 + %9207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 9, !dbg !338 + %9208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 10, !dbg !338 + %9209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 11, !dbg !338 + %9210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 12, !dbg !338 + %9211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 13, !dbg !338 + %9212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 14, !dbg !338 + %9213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 15, !dbg !338 + %9214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 16, !dbg !338 + %9215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 17, !dbg !338 + %9216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 18, !dbg !338 + %9217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 19, !dbg !338 + %9218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 20, !dbg !338 + %9219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 21, !dbg !338 + %9220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 22, !dbg !338 + %9221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 23, !dbg !338 + %9222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 24, !dbg !338 + %9223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 25, !dbg !338 + %9224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 26, !dbg !338 + %9225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 27, !dbg !338 + %9226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 28, !dbg !338 + %9227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 29, !dbg !338 + %9228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 30, !dbg !338 + %9229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9197, 31, !dbg !338 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !338 + %9230 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %9198, float %9199, float %9200, float %9201, float %9202, float %9203, float %9204, float %9205, float %9206, float %9207, float %9208, float %9209, float %9210, float %9211, float %9212, float %9213, float %9214, float %9215, float %9216, float %9217, float %9218, float %9219, float %9220, float %9221, float %9222, float %9223, float %9224, float %9225, float %9226, float %9227, float %9228, float %9229, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %8809, i32 0, i32 0) #3, !dbg !338 + %9231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 0, !dbg !338 + %9232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 1, !dbg !338 + %9233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 2, !dbg !338 + %9234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 3, !dbg !338 + %9235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 4, !dbg !338 + %9236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 5, !dbg !338 + %9237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 6, !dbg !338 + %9238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 7, !dbg !338 + %9239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 8, !dbg !338 + %9240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 9, !dbg !338 + %9241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 10, !dbg !338 + %9242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 11, !dbg !338 + %9243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 12, !dbg !338 + %9244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 13, !dbg !338 + %9245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 14, !dbg !338 + %9246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 15, !dbg !338 + %9247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 16, !dbg !338 + %9248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 17, !dbg !338 + %9249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 18, !dbg !338 + %9250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 19, !dbg !338 + %9251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 20, !dbg !338 + %9252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 21, !dbg !338 + %9253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 22, !dbg !338 + %9254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 23, !dbg !338 + %9255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 24, !dbg !338 + %9256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 25, !dbg !338 + %9257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 26, !dbg !338 + %9258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 27, !dbg !338 + %9259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 28, !dbg !338 + %9260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 29, !dbg !338 + %9261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 30, !dbg !338 + %9262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9230, 31, !dbg !338 + %9263 = fmul float %9231, 0x3FB6A09E60000000, !dbg !339 + %9264 = fmul float %9232, 0x3FB6A09E60000000, !dbg !339 + %9265 = fmul float %9233, 0x3FB6A09E60000000, !dbg !339 + %9266 = fmul float %9234, 0x3FB6A09E60000000, !dbg !339 + %9267 = fmul float %9235, 0x3FB6A09E60000000, !dbg !339 + %9268 = fmul float %9236, 0x3FB6A09E60000000, !dbg !339 + %9269 = fmul float %9237, 0x3FB6A09E60000000, !dbg !339 + %9270 = fmul float %9238, 0x3FB6A09E60000000, !dbg !339 + %9271 = fmul float %9239, 0x3FB6A09E60000000, !dbg !339 + %9272 = fmul float %9240, 0x3FB6A09E60000000, !dbg !339 + %9273 = fmul float %9241, 0x3FB6A09E60000000, !dbg !339 + %9274 = fmul float %9242, 0x3FB6A09E60000000, !dbg !339 + %9275 = fmul float %9243, 0x3FB6A09E60000000, !dbg !339 + %9276 = fmul float %9244, 0x3FB6A09E60000000, !dbg !339 + %9277 = fmul float %9245, 0x3FB6A09E60000000, !dbg !339 + %9278 = fmul float %9246, 0x3FB6A09E60000000, !dbg !339 + %9279 = fmul float %9247, 0x3FB6A09E60000000, !dbg !339 + %9280 = fmul float %9248, 0x3FB6A09E60000000, !dbg !339 + %9281 = fmul float %9249, 0x3FB6A09E60000000, !dbg !339 + %9282 = fmul float %9250, 0x3FB6A09E60000000, !dbg !339 + %9283 = fmul float %9251, 0x3FB6A09E60000000, !dbg !339 + %9284 = fmul float %9252, 0x3FB6A09E60000000, !dbg !339 + %9285 = fmul float %9253, 0x3FB6A09E60000000, !dbg !339 + %9286 = fmul float %9254, 0x3FB6A09E60000000, !dbg !339 + %9287 = fmul float %9255, 0x3FB6A09E60000000, !dbg !339 + %9288 = fmul float %9256, 0x3FB6A09E60000000, !dbg !339 + %9289 = fmul float %9257, 0x3FB6A09E60000000, !dbg !339 + %9290 = fmul float %9258, 0x3FB6A09E60000000, !dbg !339 + %9291 = fmul float %9259, 0x3FB6A09E60000000, !dbg !339 + %9292 = fmul float %9260, 0x3FB6A09E60000000, !dbg !339 + %9293 = fmul float %9261, 0x3FB6A09E60000000, !dbg !339 + %9294 = fmul float %9262, 0x3FB6A09E60000000, !dbg !339 + %9295 = fmul float %9263, 0x3FF7154760000000, !dbg !340 + %9296 = select i1 %8792, float %9295, float 0xFFF0000000000000, !dbg !341 + %9297 = fmul float %9264, 0x3FF7154760000000, !dbg !340 + %9298 = select i1 %8793, float %9297, float 0xFFF0000000000000, !dbg !341 + %9299 = fmul float %9265, 0x3FF7154760000000, !dbg !340 + %9300 = select i1 %8792, float %9299, float 0xFFF0000000000000, !dbg !341 + %9301 = fmul float %9266, 0x3FF7154760000000, !dbg !340 + %9302 = select i1 %8793, float %9301, float 0xFFF0000000000000, !dbg !341 + %9303 = fmul float %9267, 0x3FF7154760000000, !dbg !340 + %9304 = select i1 %8794, float %9303, float 0xFFF0000000000000, !dbg !341 + %9305 = fmul float %9268, 0x3FF7154760000000, !dbg !340 + %9306 = select i1 %8795, float %9305, float 0xFFF0000000000000, !dbg !341 + %9307 = fmul float %9269, 0x3FF7154760000000, !dbg !340 + %9308 = select i1 %8794, float %9307, float 0xFFF0000000000000, !dbg !341 + %9309 = fmul float %9270, 0x3FF7154760000000, !dbg !340 + %9310 = select i1 %8795, float %9309, float 0xFFF0000000000000, !dbg !341 + %9311 = fmul float %9271, 0x3FF7154760000000, !dbg !340 + %9312 = select i1 %8796, float %9311, float 0xFFF0000000000000, !dbg !341 + %9313 = fmul float %9272, 0x3FF7154760000000, !dbg !340 + %9314 = select i1 %8797, float %9313, float 0xFFF0000000000000, !dbg !341 + %9315 = fmul float %9273, 0x3FF7154760000000, !dbg !340 + %9316 = select i1 %8796, float %9315, float 0xFFF0000000000000, !dbg !341 + %9317 = fmul float %9274, 0x3FF7154760000000, !dbg !340 + %9318 = select i1 %8797, float %9317, float 0xFFF0000000000000, !dbg !341 + %9319 = fmul float %9275, 0x3FF7154760000000, !dbg !340 + %9320 = select i1 %8798, float %9319, float 0xFFF0000000000000, !dbg !341 + %9321 = fmul float %9276, 0x3FF7154760000000, !dbg !340 + %9322 = select i1 %8799, float %9321, float 0xFFF0000000000000, !dbg !341 + %9323 = fmul float %9277, 0x3FF7154760000000, !dbg !340 + %9324 = select i1 %8798, float %9323, float 0xFFF0000000000000, !dbg !341 + %9325 = fmul float %9278, 0x3FF7154760000000, !dbg !340 + %9326 = select i1 %8799, float %9325, float 0xFFF0000000000000, !dbg !341 + %9327 = fmul float %9279, 0x3FF7154760000000, !dbg !340 + %9328 = select i1 %8800, float %9327, float 0xFFF0000000000000, !dbg !341 + %9329 = fmul float %9280, 0x3FF7154760000000, !dbg !340 + %9330 = select i1 %8801, float %9329, float 0xFFF0000000000000, !dbg !341 + %9331 = fmul float %9281, 0x3FF7154760000000, !dbg !340 + %9332 = select i1 %8800, float %9331, float 0xFFF0000000000000, !dbg !341 + %9333 = fmul float %9282, 0x3FF7154760000000, !dbg !340 + %9334 = select i1 %8801, float %9333, float 0xFFF0000000000000, !dbg !341 + %9335 = fmul float %9283, 0x3FF7154760000000, !dbg !340 + %9336 = select i1 %8802, float %9335, float 0xFFF0000000000000, !dbg !341 + %9337 = fmul float %9284, 0x3FF7154760000000, !dbg !340 + %9338 = select i1 %8803, float %9337, float 0xFFF0000000000000, !dbg !341 + %9339 = fmul float %9285, 0x3FF7154760000000, !dbg !340 + %9340 = select i1 %8802, float %9339, float 0xFFF0000000000000, !dbg !341 + %9341 = fmul float %9286, 0x3FF7154760000000, !dbg !340 + %9342 = select i1 %8803, float %9341, float 0xFFF0000000000000, !dbg !341 + %9343 = fmul float %9287, 0x3FF7154760000000, !dbg !340 + %9344 = select i1 %8804, float %9343, float 0xFFF0000000000000, !dbg !341 + %9345 = fmul float %9288, 0x3FF7154760000000, !dbg !340 + %9346 = select i1 %8805, float %9345, float 0xFFF0000000000000, !dbg !341 + %9347 = fmul float %9289, 0x3FF7154760000000, !dbg !340 + %9348 = select i1 %8804, float %9347, float 0xFFF0000000000000, !dbg !341 + %9349 = fmul float %9290, 0x3FF7154760000000, !dbg !340 + %9350 = select i1 %8805, float %9349, float 0xFFF0000000000000, !dbg !341 + %9351 = fmul float %9291, 0x3FF7154760000000, !dbg !340 + %9352 = select i1 %8806, float %9351, float 0xFFF0000000000000, !dbg !341 + %9353 = fmul float %9292, 0x3FF7154760000000, !dbg !340 + %9354 = select i1 %8807, float %9353, float 0xFFF0000000000000, !dbg !341 + %9355 = fmul float %9293, 0x3FF7154760000000, !dbg !340 + %9356 = select i1 %8806, float %9355, float 0xFFF0000000000000, !dbg !341 + %9357 = fmul float %9294, 0x3FF7154760000000, !dbg !340 + %9358 = select i1 %8807, float %9357, float 0xFFF0000000000000, !dbg !341 + %9359 = fsub float %9296, %8860, !dbg !342 + %9360 = fsub float %9298, %8861, !dbg !342 + %9361 = fsub float %9300, %8860, !dbg !342 + %9362 = fsub float %9302, %8861, !dbg !342 + %9363 = fsub float %9304, %8862, !dbg !342 + %9364 = fsub float %9306, %8863, !dbg !342 + %9365 = fsub float %9308, %8862, !dbg !342 + %9366 = fsub float %9310, %8863, !dbg !342 + %9367 = fsub float %9312, %8864, !dbg !342 + %9368 = fsub float %9314, %8865, !dbg !342 + %9369 = fsub float %9316, %8864, !dbg !342 + %9370 = fsub float %9318, %8865, !dbg !342 + %9371 = fsub float %9320, %8866, !dbg !342 + %9372 = fsub float %9322, %8867, !dbg !342 + %9373 = fsub float %9324, %8866, !dbg !342 + %9374 = fsub float %9326, %8867, !dbg !342 + %9375 = fsub float %9328, %8868, !dbg !342 + %9376 = fsub float %9330, %8869, !dbg !342 + %9377 = fsub float %9332, %8868, !dbg !342 + %9378 = fsub float %9334, %8869, !dbg !342 + %9379 = fsub float %9336, %8870, !dbg !342 + %9380 = fsub float %9338, %8871, !dbg !342 + %9381 = fsub float %9340, %8870, !dbg !342 + %9382 = fsub float %9342, %8871, !dbg !342 + %9383 = fsub float %9344, %8872, !dbg !342 + %9384 = fsub float %9346, %8873, !dbg !342 + %9385 = fsub float %9348, %8872, !dbg !342 + %9386 = fsub float %9350, %8873, !dbg !342 + %9387 = fsub float %9352, %8874, !dbg !342 + %9388 = fsub float %9354, %8875, !dbg !342 + %9389 = fsub float %9356, %8874, !dbg !342 + %9390 = fsub float %9358, %8875, !dbg !342 + %9391 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i = icmp eq i32 %9391, 0, !dbg !343 + br i1 %.not.i, label %9394, label %9392, !dbg !343 + +9392: ; preds = %.lr.ph1873 + %9393 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9359) #3, !dbg !343 + br label %__nv_exp2f.exit, !dbg !343 + +9394: ; preds = %.lr.ph1873 + %9395 = tail call float @llvm.nvvm.ex2.approx.f(float %9359) #3, !dbg !343 + br label %__nv_exp2f.exit, !dbg !343 + +__nv_exp2f.exit: ; preds = %9392, %9394 + %.0.i = phi float [ %9393, %9392 ], [ %9395, %9394 ], !dbg !343 + %9396 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1229 = icmp eq i32 %9396, 0, !dbg !343 + br i1 %.not.i1229, label %9399, label %9397, !dbg !343 + +9397: ; preds = %__nv_exp2f.exit + %9398 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9360) #3, !dbg !343 + br label %__nv_exp2f.exit1231, !dbg !343 + +9399: ; preds = %__nv_exp2f.exit + %9400 = tail call float @llvm.nvvm.ex2.approx.f(float %9360) #3, !dbg !343 + br label %__nv_exp2f.exit1231, !dbg !343 + +__nv_exp2f.exit1231: ; preds = %9397, %9399 + %.0.i1230 = phi float [ %9398, %9397 ], [ %9400, %9399 ], !dbg !343 + %9401 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1232 = icmp eq i32 %9401, 0, !dbg !343 + br i1 %.not.i1232, label %9404, label %9402, !dbg !343 + +9402: ; preds = %__nv_exp2f.exit1231 + %9403 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9361) #3, !dbg !343 + br label %__nv_exp2f.exit1234, !dbg !343 + +9404: ; preds = %__nv_exp2f.exit1231 + %9405 = tail call float @llvm.nvvm.ex2.approx.f(float %9361) #3, !dbg !343 + br label %__nv_exp2f.exit1234, !dbg !343 + +__nv_exp2f.exit1234: ; preds = %9402, %9404 + %.0.i1233 = phi float [ %9403, %9402 ], [ %9405, %9404 ], !dbg !343 + %9406 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1235 = icmp eq i32 %9406, 0, !dbg !343 + br i1 %.not.i1235, label %9409, label %9407, !dbg !343 + +9407: ; preds = %__nv_exp2f.exit1234 + %9408 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9362) #3, !dbg !343 + br label %__nv_exp2f.exit1237, !dbg !343 + +9409: ; preds = %__nv_exp2f.exit1234 + %9410 = tail call float @llvm.nvvm.ex2.approx.f(float %9362) #3, !dbg !343 + br label %__nv_exp2f.exit1237, !dbg !343 + +__nv_exp2f.exit1237: ; preds = %9407, %9409 + %.0.i1236 = phi float [ %9408, %9407 ], [ %9410, %9409 ], !dbg !343 + %9411 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1238 = icmp eq i32 %9411, 0, !dbg !343 + br i1 %.not.i1238, label %9414, label %9412, !dbg !343 + +9412: ; preds = %__nv_exp2f.exit1237 + %9413 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9363) #3, !dbg !343 + br label %__nv_exp2f.exit1240, !dbg !343 + +9414: ; preds = %__nv_exp2f.exit1237 + %9415 = tail call float @llvm.nvvm.ex2.approx.f(float %9363) #3, !dbg !343 + br label %__nv_exp2f.exit1240, !dbg !343 + +__nv_exp2f.exit1240: ; preds = %9412, %9414 + %.0.i1239 = phi float [ %9413, %9412 ], [ %9415, %9414 ], !dbg !343 + %9416 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1241 = icmp eq i32 %9416, 0, !dbg !343 + br i1 %.not.i1241, label %9419, label %9417, !dbg !343 + +9417: ; preds = %__nv_exp2f.exit1240 + %9418 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9364) #3, !dbg !343 + br label %__nv_exp2f.exit1243, !dbg !343 + +9419: ; preds = %__nv_exp2f.exit1240 + %9420 = tail call float @llvm.nvvm.ex2.approx.f(float %9364) #3, !dbg !343 + br label %__nv_exp2f.exit1243, !dbg !343 + +__nv_exp2f.exit1243: ; preds = %9417, %9419 + %.0.i1242 = phi float [ %9418, %9417 ], [ %9420, %9419 ], !dbg !343 + %9421 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1244 = icmp eq i32 %9421, 0, !dbg !343 + br i1 %.not.i1244, label %9424, label %9422, !dbg !343 + +9422: ; preds = %__nv_exp2f.exit1243 + %9423 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9365) #3, !dbg !343 + br label %__nv_exp2f.exit1246, !dbg !343 + +9424: ; preds = %__nv_exp2f.exit1243 + %9425 = tail call float @llvm.nvvm.ex2.approx.f(float %9365) #3, !dbg !343 + br label %__nv_exp2f.exit1246, !dbg !343 + +__nv_exp2f.exit1246: ; preds = %9422, %9424 + %.0.i1245 = phi float [ %9423, %9422 ], [ %9425, %9424 ], !dbg !343 + %9426 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1247 = icmp eq i32 %9426, 0, !dbg !343 + br i1 %.not.i1247, label %9429, label %9427, !dbg !343 + +9427: ; preds = %__nv_exp2f.exit1246 + %9428 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9366) #3, !dbg !343 + br label %__nv_exp2f.exit1249, !dbg !343 + +9429: ; preds = %__nv_exp2f.exit1246 + %9430 = tail call float @llvm.nvvm.ex2.approx.f(float %9366) #3, !dbg !343 + br label %__nv_exp2f.exit1249, !dbg !343 + +__nv_exp2f.exit1249: ; preds = %9427, %9429 + %.0.i1248 = phi float [ %9428, %9427 ], [ %9430, %9429 ], !dbg !343 + %9431 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1250 = icmp eq i32 %9431, 0, !dbg !343 + br i1 %.not.i1250, label %9434, label %9432, !dbg !343 + +9432: ; preds = %__nv_exp2f.exit1249 + %9433 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9367) #3, !dbg !343 + br label %__nv_exp2f.exit1252, !dbg !343 + +9434: ; preds = %__nv_exp2f.exit1249 + %9435 = tail call float @llvm.nvvm.ex2.approx.f(float %9367) #3, !dbg !343 + br label %__nv_exp2f.exit1252, !dbg !343 + +__nv_exp2f.exit1252: ; preds = %9432, %9434 + %.0.i1251 = phi float [ %9433, %9432 ], [ %9435, %9434 ], !dbg !343 + %9436 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1253 = icmp eq i32 %9436, 0, !dbg !343 + br i1 %.not.i1253, label %9439, label %9437, !dbg !343 + +9437: ; preds = %__nv_exp2f.exit1252 + %9438 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9368) #3, !dbg !343 + br label %__nv_exp2f.exit1255, !dbg !343 + +9439: ; preds = %__nv_exp2f.exit1252 + %9440 = tail call float @llvm.nvvm.ex2.approx.f(float %9368) #3, !dbg !343 + br label %__nv_exp2f.exit1255, !dbg !343 + +__nv_exp2f.exit1255: ; preds = %9437, %9439 + %.0.i1254 = phi float [ %9438, %9437 ], [ %9440, %9439 ], !dbg !343 + %9441 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1256 = icmp eq i32 %9441, 0, !dbg !343 + br i1 %.not.i1256, label %9444, label %9442, !dbg !343 + +9442: ; preds = %__nv_exp2f.exit1255 + %9443 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9369) #3, !dbg !343 + br label %__nv_exp2f.exit1258, !dbg !343 + +9444: ; preds = %__nv_exp2f.exit1255 + %9445 = tail call float @llvm.nvvm.ex2.approx.f(float %9369) #3, !dbg !343 + br label %__nv_exp2f.exit1258, !dbg !343 + +__nv_exp2f.exit1258: ; preds = %9442, %9444 + %.0.i1257 = phi float [ %9443, %9442 ], [ %9445, %9444 ], !dbg !343 + %9446 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1259 = icmp eq i32 %9446, 0, !dbg !343 + br i1 %.not.i1259, label %9449, label %9447, !dbg !343 + +9447: ; preds = %__nv_exp2f.exit1258 + %9448 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9370) #3, !dbg !343 + br label %__nv_exp2f.exit1261, !dbg !343 + +9449: ; preds = %__nv_exp2f.exit1258 + %9450 = tail call float @llvm.nvvm.ex2.approx.f(float %9370) #3, !dbg !343 + br label %__nv_exp2f.exit1261, !dbg !343 + +__nv_exp2f.exit1261: ; preds = %9447, %9449 + %.0.i1260 = phi float [ %9448, %9447 ], [ %9450, %9449 ], !dbg !343 + %9451 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1262 = icmp eq i32 %9451, 0, !dbg !343 + br i1 %.not.i1262, label %9454, label %9452, !dbg !343 + +9452: ; preds = %__nv_exp2f.exit1261 + %9453 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9371) #3, !dbg !343 + br label %__nv_exp2f.exit1264, !dbg !343 + +9454: ; preds = %__nv_exp2f.exit1261 + %9455 = tail call float @llvm.nvvm.ex2.approx.f(float %9371) #3, !dbg !343 + br label %__nv_exp2f.exit1264, !dbg !343 + +__nv_exp2f.exit1264: ; preds = %9452, %9454 + %.0.i1263 = phi float [ %9453, %9452 ], [ %9455, %9454 ], !dbg !343 + %9456 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1265 = icmp eq i32 %9456, 0, !dbg !343 + br i1 %.not.i1265, label %9459, label %9457, !dbg !343 + +9457: ; preds = %__nv_exp2f.exit1264 + %9458 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9372) #3, !dbg !343 + br label %__nv_exp2f.exit1267, !dbg !343 + +9459: ; preds = %__nv_exp2f.exit1264 + %9460 = tail call float @llvm.nvvm.ex2.approx.f(float %9372) #3, !dbg !343 + br label %__nv_exp2f.exit1267, !dbg !343 + +__nv_exp2f.exit1267: ; preds = %9457, %9459 + %.0.i1266 = phi float [ %9458, %9457 ], [ %9460, %9459 ], !dbg !343 + %9461 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1268 = icmp eq i32 %9461, 0, !dbg !343 + br i1 %.not.i1268, label %9464, label %9462, !dbg !343 + +9462: ; preds = %__nv_exp2f.exit1267 + %9463 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9373) #3, !dbg !343 + br label %__nv_exp2f.exit1270, !dbg !343 + +9464: ; preds = %__nv_exp2f.exit1267 + %9465 = tail call float @llvm.nvvm.ex2.approx.f(float %9373) #3, !dbg !343 + br label %__nv_exp2f.exit1270, !dbg !343 + +__nv_exp2f.exit1270: ; preds = %9462, %9464 + %.0.i1269 = phi float [ %9463, %9462 ], [ %9465, %9464 ], !dbg !343 + %9466 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1271 = icmp eq i32 %9466, 0, !dbg !343 + br i1 %.not.i1271, label %9469, label %9467, !dbg !343 + +9467: ; preds = %__nv_exp2f.exit1270 + %9468 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9374) #3, !dbg !343 + br label %__nv_exp2f.exit1273, !dbg !343 + +9469: ; preds = %__nv_exp2f.exit1270 + %9470 = tail call float @llvm.nvvm.ex2.approx.f(float %9374) #3, !dbg !343 + br label %__nv_exp2f.exit1273, !dbg !343 + +__nv_exp2f.exit1273: ; preds = %9467, %9469 + %.0.i1272 = phi float [ %9468, %9467 ], [ %9470, %9469 ], !dbg !343 + %9471 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1274 = icmp eq i32 %9471, 0, !dbg !343 + br i1 %.not.i1274, label %9474, label %9472, !dbg !343 + +9472: ; preds = %__nv_exp2f.exit1273 + %9473 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9375) #3, !dbg !343 + br label %__nv_exp2f.exit1276, !dbg !343 + +9474: ; preds = %__nv_exp2f.exit1273 + %9475 = tail call float @llvm.nvvm.ex2.approx.f(float %9375) #3, !dbg !343 + br label %__nv_exp2f.exit1276, !dbg !343 + +__nv_exp2f.exit1276: ; preds = %9472, %9474 + %.0.i1275 = phi float [ %9473, %9472 ], [ %9475, %9474 ], !dbg !343 + %9476 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1277 = icmp eq i32 %9476, 0, !dbg !343 + br i1 %.not.i1277, label %9479, label %9477, !dbg !343 + +9477: ; preds = %__nv_exp2f.exit1276 + %9478 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9376) #3, !dbg !343 + br label %__nv_exp2f.exit1279, !dbg !343 + +9479: ; preds = %__nv_exp2f.exit1276 + %9480 = tail call float @llvm.nvvm.ex2.approx.f(float %9376) #3, !dbg !343 + br label %__nv_exp2f.exit1279, !dbg !343 + +__nv_exp2f.exit1279: ; preds = %9477, %9479 + %.0.i1278 = phi float [ %9478, %9477 ], [ %9480, %9479 ], !dbg !343 + %9481 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1280 = icmp eq i32 %9481, 0, !dbg !343 + br i1 %.not.i1280, label %9484, label %9482, !dbg !343 + +9482: ; preds = %__nv_exp2f.exit1279 + %9483 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9377) #3, !dbg !343 + br label %__nv_exp2f.exit1282, !dbg !343 + +9484: ; preds = %__nv_exp2f.exit1279 + %9485 = tail call float @llvm.nvvm.ex2.approx.f(float %9377) #3, !dbg !343 + br label %__nv_exp2f.exit1282, !dbg !343 + +__nv_exp2f.exit1282: ; preds = %9482, %9484 + %.0.i1281 = phi float [ %9483, %9482 ], [ %9485, %9484 ], !dbg !343 + %9486 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1283 = icmp eq i32 %9486, 0, !dbg !343 + br i1 %.not.i1283, label %9489, label %9487, !dbg !343 + +9487: ; preds = %__nv_exp2f.exit1282 + %9488 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9378) #3, !dbg !343 + br label %__nv_exp2f.exit1285, !dbg !343 + +9489: ; preds = %__nv_exp2f.exit1282 + %9490 = tail call float @llvm.nvvm.ex2.approx.f(float %9378) #3, !dbg !343 + br label %__nv_exp2f.exit1285, !dbg !343 + +__nv_exp2f.exit1285: ; preds = %9487, %9489 + %.0.i1284 = phi float [ %9488, %9487 ], [ %9490, %9489 ], !dbg !343 + %9491 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1286 = icmp eq i32 %9491, 0, !dbg !343 + br i1 %.not.i1286, label %9494, label %9492, !dbg !343 + +9492: ; preds = %__nv_exp2f.exit1285 + %9493 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9379) #3, !dbg !343 + br label %__nv_exp2f.exit1288, !dbg !343 + +9494: ; preds = %__nv_exp2f.exit1285 + %9495 = tail call float @llvm.nvvm.ex2.approx.f(float %9379) #3, !dbg !343 + br label %__nv_exp2f.exit1288, !dbg !343 + +__nv_exp2f.exit1288: ; preds = %9492, %9494 + %.0.i1287 = phi float [ %9493, %9492 ], [ %9495, %9494 ], !dbg !343 + %9496 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1289 = icmp eq i32 %9496, 0, !dbg !343 + br i1 %.not.i1289, label %9499, label %9497, !dbg !343 + +9497: ; preds = %__nv_exp2f.exit1288 + %9498 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9380) #3, !dbg !343 + br label %__nv_exp2f.exit1291, !dbg !343 + +9499: ; preds = %__nv_exp2f.exit1288 + %9500 = tail call float @llvm.nvvm.ex2.approx.f(float %9380) #3, !dbg !343 + br label %__nv_exp2f.exit1291, !dbg !343 + +__nv_exp2f.exit1291: ; preds = %9497, %9499 + %.0.i1290 = phi float [ %9498, %9497 ], [ %9500, %9499 ], !dbg !343 + %9501 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1292 = icmp eq i32 %9501, 0, !dbg !343 + br i1 %.not.i1292, label %9504, label %9502, !dbg !343 + +9502: ; preds = %__nv_exp2f.exit1291 + %9503 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9381) #3, !dbg !343 + br label %__nv_exp2f.exit1294, !dbg !343 + +9504: ; preds = %__nv_exp2f.exit1291 + %9505 = tail call float @llvm.nvvm.ex2.approx.f(float %9381) #3, !dbg !343 + br label %__nv_exp2f.exit1294, !dbg !343 + +__nv_exp2f.exit1294: ; preds = %9502, %9504 + %.0.i1293 = phi float [ %9503, %9502 ], [ %9505, %9504 ], !dbg !343 + %9506 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1295 = icmp eq i32 %9506, 0, !dbg !343 + br i1 %.not.i1295, label %9509, label %9507, !dbg !343 + +9507: ; preds = %__nv_exp2f.exit1294 + %9508 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9382) #3, !dbg !343 + br label %__nv_exp2f.exit1297, !dbg !343 + +9509: ; preds = %__nv_exp2f.exit1294 + %9510 = tail call float @llvm.nvvm.ex2.approx.f(float %9382) #3, !dbg !343 + br label %__nv_exp2f.exit1297, !dbg !343 + +__nv_exp2f.exit1297: ; preds = %9507, %9509 + %.0.i1296 = phi float [ %9508, %9507 ], [ %9510, %9509 ], !dbg !343 + %9511 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1298 = icmp eq i32 %9511, 0, !dbg !343 + br i1 %.not.i1298, label %9514, label %9512, !dbg !343 + +9512: ; preds = %__nv_exp2f.exit1297 + %9513 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9383) #3, !dbg !343 + br label %__nv_exp2f.exit1300, !dbg !343 + +9514: ; preds = %__nv_exp2f.exit1297 + %9515 = tail call float @llvm.nvvm.ex2.approx.f(float %9383) #3, !dbg !343 + br label %__nv_exp2f.exit1300, !dbg !343 + +__nv_exp2f.exit1300: ; preds = %9512, %9514 + %.0.i1299 = phi float [ %9513, %9512 ], [ %9515, %9514 ], !dbg !343 + %9516 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1301 = icmp eq i32 %9516, 0, !dbg !343 + br i1 %.not.i1301, label %9519, label %9517, !dbg !343 + +9517: ; preds = %__nv_exp2f.exit1300 + %9518 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9384) #3, !dbg !343 + br label %__nv_exp2f.exit1303, !dbg !343 + +9519: ; preds = %__nv_exp2f.exit1300 + %9520 = tail call float @llvm.nvvm.ex2.approx.f(float %9384) #3, !dbg !343 + br label %__nv_exp2f.exit1303, !dbg !343 + +__nv_exp2f.exit1303: ; preds = %9517, %9519 + %.0.i1302 = phi float [ %9518, %9517 ], [ %9520, %9519 ], !dbg !343 + %9521 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1304 = icmp eq i32 %9521, 0, !dbg !343 + br i1 %.not.i1304, label %9524, label %9522, !dbg !343 + +9522: ; preds = %__nv_exp2f.exit1303 + %9523 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9385) #3, !dbg !343 + br label %__nv_exp2f.exit1306, !dbg !343 + +9524: ; preds = %__nv_exp2f.exit1303 + %9525 = tail call float @llvm.nvvm.ex2.approx.f(float %9385) #3, !dbg !343 + br label %__nv_exp2f.exit1306, !dbg !343 + +__nv_exp2f.exit1306: ; preds = %9522, %9524 + %.0.i1305 = phi float [ %9523, %9522 ], [ %9525, %9524 ], !dbg !343 + %9526 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1307 = icmp eq i32 %9526, 0, !dbg !343 + br i1 %.not.i1307, label %9529, label %9527, !dbg !343 + +9527: ; preds = %__nv_exp2f.exit1306 + %9528 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9386) #3, !dbg !343 + br label %__nv_exp2f.exit1309, !dbg !343 + +9529: ; preds = %__nv_exp2f.exit1306 + %9530 = tail call float @llvm.nvvm.ex2.approx.f(float %9386) #3, !dbg !343 + br label %__nv_exp2f.exit1309, !dbg !343 + +__nv_exp2f.exit1309: ; preds = %9527, %9529 + %.0.i1308 = phi float [ %9528, %9527 ], [ %9530, %9529 ], !dbg !343 + %9531 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1310 = icmp eq i32 %9531, 0, !dbg !343 + br i1 %.not.i1310, label %9534, label %9532, !dbg !343 + +9532: ; preds = %__nv_exp2f.exit1309 + %9533 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9387) #3, !dbg !343 + br label %__nv_exp2f.exit1312, !dbg !343 + +9534: ; preds = %__nv_exp2f.exit1309 + %9535 = tail call float @llvm.nvvm.ex2.approx.f(float %9387) #3, !dbg !343 + br label %__nv_exp2f.exit1312, !dbg !343 + +__nv_exp2f.exit1312: ; preds = %9532, %9534 + %.0.i1311 = phi float [ %9533, %9532 ], [ %9535, %9534 ], !dbg !343 + %9536 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1313 = icmp eq i32 %9536, 0, !dbg !343 + br i1 %.not.i1313, label %9539, label %9537, !dbg !343 + +9537: ; preds = %__nv_exp2f.exit1312 + %9538 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9388) #3, !dbg !343 + br label %__nv_exp2f.exit1315, !dbg !343 + +9539: ; preds = %__nv_exp2f.exit1312 + %9540 = tail call float @llvm.nvvm.ex2.approx.f(float %9388) #3, !dbg !343 + br label %__nv_exp2f.exit1315, !dbg !343 + +__nv_exp2f.exit1315: ; preds = %9537, %9539 + %.0.i1314 = phi float [ %9538, %9537 ], [ %9540, %9539 ], !dbg !343 + %9541 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1316 = icmp eq i32 %9541, 0, !dbg !343 + br i1 %.not.i1316, label %9544, label %9542, !dbg !343 + +9542: ; preds = %__nv_exp2f.exit1315 + %9543 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9389) #3, !dbg !343 + br label %__nv_exp2f.exit1318, !dbg !343 + +9544: ; preds = %__nv_exp2f.exit1315 + %9545 = tail call float @llvm.nvvm.ex2.approx.f(float %9389) #3, !dbg !343 + br label %__nv_exp2f.exit1318, !dbg !343 + +__nv_exp2f.exit1318: ; preds = %9542, %9544 + %.0.i1317 = phi float [ %9543, %9542 ], [ %9545, %9544 ], !dbg !343 + %9546 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !343 + %.not.i1319 = icmp eq i32 %9546, 0, !dbg !343 + br i1 %.not.i1319, label %9549, label %9547, !dbg !343 + +9547: ; preds = %__nv_exp2f.exit1318 + %9548 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9390) #3, !dbg !343 + br label %__nv_exp2f.exit1321, !dbg !343 + +9549: ; preds = %__nv_exp2f.exit1318 + %9550 = tail call float @llvm.nvvm.ex2.approx.f(float %9390) #3, !dbg !343 + br label %__nv_exp2f.exit1321, !dbg !343 + +__nv_exp2f.exit1321: ; preds = %9547, %9549 + %.0.i1320 = phi float [ %9548, %9547 ], [ %9550, %9549 ], !dbg !343 + %9551 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %8808, !dbg !329 + %9552 = insertelement <2 x float> poison, float %.0.i, i64 0, !dbg !344 + %9553 = insertelement <2 x float> %9552, float %.0.i1230, i64 1, !dbg !344 + %9554 = fptrunc <2 x float> %9553 to <2 x bfloat>, !dbg !344 + %9555 = insertelement <2 x float> poison, float %.0.i1233, i64 0, !dbg !344 + %9556 = insertelement <2 x float> %9555, float %.0.i1236, i64 1, !dbg !344 + %9557 = fptrunc <2 x float> %9556 to <2 x bfloat>, !dbg !344 + %9558 = insertelement <2 x float> poison, float %.0.i1239, i64 0, !dbg !344 + %9559 = insertelement <2 x float> %9558, float %.0.i1242, i64 1, !dbg !344 + %9560 = fptrunc <2 x float> %9559 to <2 x bfloat>, !dbg !344 + %9561 = insertelement <2 x float> poison, float %.0.i1245, i64 0, !dbg !344 + %9562 = insertelement <2 x float> %9561, float %.0.i1248, i64 1, !dbg !344 + %9563 = fptrunc <2 x float> %9562 to <2 x bfloat>, !dbg !344 + %9564 = insertelement <2 x float> poison, float %.0.i1251, i64 0, !dbg !344 + %9565 = insertelement <2 x float> %9564, float %.0.i1254, i64 1, !dbg !344 + %9566 = fptrunc <2 x float> %9565 to <2 x bfloat>, !dbg !344 + %9567 = insertelement <2 x float> poison, float %.0.i1257, i64 0, !dbg !344 + %9568 = insertelement <2 x float> %9567, float %.0.i1260, i64 1, !dbg !344 + %9569 = fptrunc <2 x float> %9568 to <2 x bfloat>, !dbg !344 + %9570 = insertelement <2 x float> poison, float %.0.i1263, i64 0, !dbg !344 + %9571 = insertelement <2 x float> %9570, float %.0.i1266, i64 1, !dbg !344 + %9572 = fptrunc <2 x float> %9571 to <2 x bfloat>, !dbg !344 + %9573 = insertelement <2 x float> poison, float %.0.i1269, i64 0, !dbg !344 + %9574 = insertelement <2 x float> %9573, float %.0.i1272, i64 1, !dbg !344 + %9575 = fptrunc <2 x float> %9574 to <2 x bfloat>, !dbg !344 + %9576 = insertelement <2 x float> poison, float %.0.i1275, i64 0, !dbg !344 + %9577 = insertelement <2 x float> %9576, float %.0.i1278, i64 1, !dbg !344 + %9578 = fptrunc <2 x float> %9577 to <2 x bfloat>, !dbg !344 + %9579 = insertelement <2 x float> poison, float %.0.i1281, i64 0, !dbg !344 + %9580 = insertelement <2 x float> %9579, float %.0.i1284, i64 1, !dbg !344 + %9581 = fptrunc <2 x float> %9580 to <2 x bfloat>, !dbg !344 + %9582 = insertelement <2 x float> poison, float %.0.i1287, i64 0, !dbg !344 + %9583 = insertelement <2 x float> %9582, float %.0.i1290, i64 1, !dbg !344 + %9584 = fptrunc <2 x float> %9583 to <2 x bfloat>, !dbg !344 + %9585 = insertelement <2 x float> poison, float %.0.i1293, i64 0, !dbg !344 + %9586 = insertelement <2 x float> %9585, float %.0.i1296, i64 1, !dbg !344 + %9587 = fptrunc <2 x float> %9586 to <2 x bfloat>, !dbg !344 + %9588 = insertelement <2 x float> poison, float %.0.i1299, i64 0, !dbg !344 + %9589 = insertelement <2 x float> %9588, float %.0.i1302, i64 1, !dbg !344 + %9590 = fptrunc <2 x float> %9589 to <2 x bfloat>, !dbg !344 + %9591 = insertelement <2 x float> poison, float %.0.i1305, i64 0, !dbg !344 + %9592 = insertelement <2 x float> %9591, float %.0.i1308, i64 1, !dbg !344 + %9593 = fptrunc <2 x float> %9592 to <2 x bfloat>, !dbg !344 + %9594 = insertelement <2 x float> poison, float %.0.i1311, i64 0, !dbg !344 + %9595 = insertelement <2 x float> %9594, float %.0.i1314, i64 1, !dbg !344 + %9596 = fptrunc <2 x float> %9595 to <2 x bfloat>, !dbg !344 + %9597 = insertelement <2 x float> poison, float %.0.i1317, i64 0, !dbg !344 + %9598 = insertelement <2 x float> %9597, float %.0.i1320, i64 1, !dbg !344 + %9599 = fptrunc <2 x float> %9598 to <2 x bfloat>, !dbg !344 + %9600 = bitcast <2 x bfloat> %9554 to i32, !dbg !345 + %9601 = bitcast <2 x bfloat> %9557 to i32, !dbg !345 + %9602 = bitcast <2 x bfloat> %9560 to i32, !dbg !345 + %9603 = bitcast <2 x bfloat> %9563 to i32, !dbg !345 + %9604 = bitcast <2 x bfloat> %9566 to i32, !dbg !345 + %9605 = bitcast <2 x bfloat> %9569 to i32, !dbg !345 + %9606 = bitcast <2 x bfloat> %9572 to i32, !dbg !345 + %9607 = bitcast <2 x bfloat> %9575 to i32, !dbg !345 + %9608 = bitcast <2 x bfloat> %9578 to i32, !dbg !345 + %9609 = bitcast <2 x bfloat> %9581 to i32, !dbg !345 + %9610 = bitcast <2 x bfloat> %9584 to i32, !dbg !345 + %9611 = bitcast <2 x bfloat> %9587 to i32, !dbg !345 + %9612 = bitcast <2 x bfloat> %9590 to i32, !dbg !345 + %9613 = bitcast <2 x bfloat> %9593 to i32, !dbg !345 + %9614 = bitcast <2 x bfloat> %9596 to i32, !dbg !345 + %9615 = bitcast <2 x bfloat> %9599 to i32, !dbg !345 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !345 + %9616 = ptrtoint ptr addrspace(3) %9551 to i32, !dbg !345 + %9617 = lshr exact i32 %9616, 4, !dbg !345 + %9618 = and i32 %9617, 16383, !dbg !345 + %9619 = zext nneg i32 %9618 to i64, !dbg !345 + %9620 = or disjoint i64 %9619, 4611686293338849280, !dbg !345 + %9621 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn4751768, float %.pn4731769, float %.pn4711770, float %.pn4691771, float %.pn4671772, float %.pn4651773, float %.pn4631774, float %.pn4611775, float %.pn4591776, float %.pn4571777, float %.pn4551778, float %.pn4531779, float %.pn4511780, float %.pn4491781, float %.pn4471782, float %.pn4451783, float %.pn4431784, float %.pn4411785, float %.pn4391786, float %.pn4371787, float %.pn4351788, float %.pn4331789, float %.pn4311790, float %.pn4291791, float %.pn4271792, float %.pn4251793, float %.pn4231794, float %.pn4211795, float %.pn4191796, float %.pn4171797, float %.pn4151798, float %.pn4131799, float %.pn4111800, float %.pn4091801, float %.pn4071802, float %.pn4051803, float %.pn4031804, float %.pn4011805, float %.pn3991806, float %.pn3971807, float %.pn3951808, float %.pn3931809, float %.pn3911810, float %.pn3891811, float %.pn3871812, float %.pn3851813, float %.pn3831814, float %.pn3811815, float %.pn3791816, float %.pn3771817, float %.pn3751818, float %.pn3731819, float %.pn3711820, float %.pn3691821, float %.pn3671822, float %.pn3651823, float %.pn3631824, float %.pn3611825, float %.pn3591826, float %.pn3571827, float %.pn3551828, float %.pn3531829, float %.pn3511830, float %.pn3491831, i32 %9600, i32 %9601, i32 %9602, i32 %9603, i64 %9620, i1 true) #3, !dbg !345 + %9622 = add i32 %9616, 2048, !dbg !345 + %9623 = lshr exact i32 %9622, 4, !dbg !345 + %9624 = and i32 %9623, 16383, !dbg !345 + %9625 = zext nneg i32 %9624 to i64, !dbg !345 + %9626 = or disjoint i64 %9625, 4611686293338849280, !dbg !345 + %9627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 0, !dbg !345 + %9628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 1, !dbg !345 + %9629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 2, !dbg !345 + %9630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 3, !dbg !345 + %9631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 4, !dbg !345 + %9632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 5, !dbg !345 + %9633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 6, !dbg !345 + %9634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 7, !dbg !345 + %9635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 8, !dbg !345 + %9636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 9, !dbg !345 + %9637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 10, !dbg !345 + %9638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 11, !dbg !345 + %9639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 12, !dbg !345 + %9640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 13, !dbg !345 + %9641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 14, !dbg !345 + %9642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 15, !dbg !345 + %9643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 16, !dbg !345 + %9644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 17, !dbg !345 + %9645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 18, !dbg !345 + %9646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 19, !dbg !345 + %9647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 20, !dbg !345 + %9648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 21, !dbg !345 + %9649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 22, !dbg !345 + %9650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 23, !dbg !345 + %9651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 24, !dbg !345 + %9652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 25, !dbg !345 + %9653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 26, !dbg !345 + %9654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 27, !dbg !345 + %9655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 28, !dbg !345 + %9656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 29, !dbg !345 + %9657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 30, !dbg !345 + %9658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 31, !dbg !345 + %9659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 32, !dbg !345 + %9660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 33, !dbg !345 + %9661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 34, !dbg !345 + %9662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 35, !dbg !345 + %9663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 36, !dbg !345 + %9664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 37, !dbg !345 + %9665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 38, !dbg !345 + %9666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 39, !dbg !345 + %9667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 40, !dbg !345 + %9668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 41, !dbg !345 + %9669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 42, !dbg !345 + %9670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 43, !dbg !345 + %9671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 44, !dbg !345 + %9672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 45, !dbg !345 + %9673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 46, !dbg !345 + %9674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 47, !dbg !345 + %9675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 48, !dbg !345 + %9676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 49, !dbg !345 + %9677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 50, !dbg !345 + %9678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 51, !dbg !345 + %9679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 52, !dbg !345 + %9680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 53, !dbg !345 + %9681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 54, !dbg !345 + %9682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 55, !dbg !345 + %9683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 56, !dbg !345 + %9684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 57, !dbg !345 + %9685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 58, !dbg !345 + %9686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 59, !dbg !345 + %9687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 60, !dbg !345 + %9688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 61, !dbg !345 + %9689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 62, !dbg !345 + %9690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9621, 63, !dbg !345 + %9691 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9627, float %9628, float %9629, float %9630, float %9631, float %9632, float %9633, float %9634, float %9635, float %9636, float %9637, float %9638, float %9639, float %9640, float %9641, float %9642, float %9643, float %9644, float %9645, float %9646, float %9647, float %9648, float %9649, float %9650, float %9651, float %9652, float %9653, float %9654, float %9655, float %9656, float %9657, float %9658, float %9659, float %9660, float %9661, float %9662, float %9663, float %9664, float %9665, float %9666, float %9667, float %9668, float %9669, float %9670, float %9671, float %9672, float %9673, float %9674, float %9675, float %9676, float %9677, float %9678, float %9679, float %9680, float %9681, float %9682, float %9683, float %9684, float %9685, float %9686, float %9687, float %9688, float %9689, float %9690, i32 %9604, i32 %9605, i32 %9606, i32 %9607, i64 %9626, i1 true) #3, !dbg !345 + %9692 = add i32 %9616, 4096, !dbg !345 + %9693 = lshr exact i32 %9692, 4, !dbg !345 + %9694 = and i32 %9693, 16383, !dbg !345 + %9695 = zext nneg i32 %9694 to i64, !dbg !345 + %9696 = or disjoint i64 %9695, 4611686293338849280, !dbg !345 + %9697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 0, !dbg !345 + %9698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 1, !dbg !345 + %9699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 2, !dbg !345 + %9700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 3, !dbg !345 + %9701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 4, !dbg !345 + %9702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 5, !dbg !345 + %9703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 6, !dbg !345 + %9704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 7, !dbg !345 + %9705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 8, !dbg !345 + %9706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 9, !dbg !345 + %9707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 10, !dbg !345 + %9708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 11, !dbg !345 + %9709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 12, !dbg !345 + %9710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 13, !dbg !345 + %9711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 14, !dbg !345 + %9712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 15, !dbg !345 + %9713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 16, !dbg !345 + %9714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 17, !dbg !345 + %9715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 18, !dbg !345 + %9716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 19, !dbg !345 + %9717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 20, !dbg !345 + %9718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 21, !dbg !345 + %9719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 22, !dbg !345 + %9720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 23, !dbg !345 + %9721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 24, !dbg !345 + %9722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 25, !dbg !345 + %9723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 26, !dbg !345 + %9724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 27, !dbg !345 + %9725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 28, !dbg !345 + %9726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 29, !dbg !345 + %9727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 30, !dbg !345 + %9728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 31, !dbg !345 + %9729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 32, !dbg !345 + %9730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 33, !dbg !345 + %9731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 34, !dbg !345 + %9732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 35, !dbg !345 + %9733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 36, !dbg !345 + %9734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 37, !dbg !345 + %9735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 38, !dbg !345 + %9736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 39, !dbg !345 + %9737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 40, !dbg !345 + %9738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 41, !dbg !345 + %9739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 42, !dbg !345 + %9740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 43, !dbg !345 + %9741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 44, !dbg !345 + %9742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 45, !dbg !345 + %9743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 46, !dbg !345 + %9744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 47, !dbg !345 + %9745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 48, !dbg !345 + %9746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 49, !dbg !345 + %9747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 50, !dbg !345 + %9748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 51, !dbg !345 + %9749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 52, !dbg !345 + %9750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 53, !dbg !345 + %9751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 54, !dbg !345 + %9752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 55, !dbg !345 + %9753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 56, !dbg !345 + %9754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 57, !dbg !345 + %9755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 58, !dbg !345 + %9756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 59, !dbg !345 + %9757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 60, !dbg !345 + %9758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 61, !dbg !345 + %9759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 62, !dbg !345 + %9760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9691, 63, !dbg !345 + %9761 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9697, float %9698, float %9699, float %9700, float %9701, float %9702, float %9703, float %9704, float %9705, float %9706, float %9707, float %9708, float %9709, float %9710, float %9711, float %9712, float %9713, float %9714, float %9715, float %9716, float %9717, float %9718, float %9719, float %9720, float %9721, float %9722, float %9723, float %9724, float %9725, float %9726, float %9727, float %9728, float %9729, float %9730, float %9731, float %9732, float %9733, float %9734, float %9735, float %9736, float %9737, float %9738, float %9739, float %9740, float %9741, float %9742, float %9743, float %9744, float %9745, float %9746, float %9747, float %9748, float %9749, float %9750, float %9751, float %9752, float %9753, float %9754, float %9755, float %9756, float %9757, float %9758, float %9759, float %9760, i32 %9608, i32 %9609, i32 %9610, i32 %9611, i64 %9696, i1 true) #3, !dbg !345 + %9762 = add i32 %9616, 6144, !dbg !345 + %9763 = lshr exact i32 %9762, 4, !dbg !345 + %9764 = and i32 %9763, 16383, !dbg !345 + %9765 = zext nneg i32 %9764 to i64, !dbg !345 + %9766 = or disjoint i64 %9765, 4611686293338849280, !dbg !345 + %9767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 0, !dbg !345 + %9768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 1, !dbg !345 + %9769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 2, !dbg !345 + %9770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 3, !dbg !345 + %9771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 4, !dbg !345 + %9772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 5, !dbg !345 + %9773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 6, !dbg !345 + %9774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 7, !dbg !345 + %9775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 8, !dbg !345 + %9776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 9, !dbg !345 + %9777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 10, !dbg !345 + %9778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 11, !dbg !345 + %9779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 12, !dbg !345 + %9780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 13, !dbg !345 + %9781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 14, !dbg !345 + %9782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 15, !dbg !345 + %9783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 16, !dbg !345 + %9784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 17, !dbg !345 + %9785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 18, !dbg !345 + %9786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 19, !dbg !345 + %9787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 20, !dbg !345 + %9788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 21, !dbg !345 + %9789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 22, !dbg !345 + %9790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 23, !dbg !345 + %9791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 24, !dbg !345 + %9792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 25, !dbg !345 + %9793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 26, !dbg !345 + %9794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 27, !dbg !345 + %9795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 28, !dbg !345 + %9796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 29, !dbg !345 + %9797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 30, !dbg !345 + %9798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 31, !dbg !345 + %9799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 32, !dbg !345 + %9800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 33, !dbg !345 + %9801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 34, !dbg !345 + %9802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 35, !dbg !345 + %9803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 36, !dbg !345 + %9804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 37, !dbg !345 + %9805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 38, !dbg !345 + %9806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 39, !dbg !345 + %9807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 40, !dbg !345 + %9808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 41, !dbg !345 + %9809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 42, !dbg !345 + %9810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 43, !dbg !345 + %9811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 44, !dbg !345 + %9812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 45, !dbg !345 + %9813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 46, !dbg !345 + %9814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 47, !dbg !345 + %9815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 48, !dbg !345 + %9816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 49, !dbg !345 + %9817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 50, !dbg !345 + %9818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 51, !dbg !345 + %9819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 52, !dbg !345 + %9820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 53, !dbg !345 + %9821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 54, !dbg !345 + %9822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 55, !dbg !345 + %9823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 56, !dbg !345 + %9824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 57, !dbg !345 + %9825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 58, !dbg !345 + %9826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 59, !dbg !345 + %9827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 60, !dbg !345 + %9828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 61, !dbg !345 + %9829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 62, !dbg !345 + %9830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9761, 63, !dbg !345 + %9831 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9767, float %9768, float %9769, float %9770, float %9771, float %9772, float %9773, float %9774, float %9775, float %9776, float %9777, float %9778, float %9779, float %9780, float %9781, float %9782, float %9783, float %9784, float %9785, float %9786, float %9787, float %9788, float %9789, float %9790, float %9791, float %9792, float %9793, float %9794, float %9795, float %9796, float %9797, float %9798, float %9799, float %9800, float %9801, float %9802, float %9803, float %9804, float %9805, float %9806, float %9807, float %9808, float %9809, float %9810, float %9811, float %9812, float %9813, float %9814, float %9815, float %9816, float %9817, float %9818, float %9819, float %9820, float %9821, float %9822, float %9823, float %9824, float %9825, float %9826, float %9827, float %9828, float %9829, float %9830, i32 %9612, i32 %9613, i32 %9614, i32 %9615, i64 %9766, i1 true) #3, !dbg !345 + %9832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 0, !dbg !345 + %9833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 1, !dbg !345 + %9834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 2, !dbg !345 + %9835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 3, !dbg !345 + %9836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 4, !dbg !345 + %9837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 5, !dbg !345 + %9838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 6, !dbg !345 + %9839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 7, !dbg !345 + %9840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 8, !dbg !345 + %9841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 9, !dbg !345 + %9842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 10, !dbg !345 + %9843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 11, !dbg !345 + %9844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 12, !dbg !345 + %9845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 13, !dbg !345 + %9846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 14, !dbg !345 + %9847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 15, !dbg !345 + %9848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 16, !dbg !345 + %9849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 17, !dbg !345 + %9850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 18, !dbg !345 + %9851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 19, !dbg !345 + %9852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 20, !dbg !345 + %9853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 21, !dbg !345 + %9854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 22, !dbg !345 + %9855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 23, !dbg !345 + %9856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 24, !dbg !345 + %9857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 25, !dbg !345 + %9858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 26, !dbg !345 + %9859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 27, !dbg !345 + %9860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 28, !dbg !345 + %9861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 29, !dbg !345 + %9862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 30, !dbg !345 + %9863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 31, !dbg !345 + %9864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 32, !dbg !345 + %9865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 33, !dbg !345 + %9866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 34, !dbg !345 + %9867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 35, !dbg !345 + %9868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 36, !dbg !345 + %9869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 37, !dbg !345 + %9870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 38, !dbg !345 + %9871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 39, !dbg !345 + %9872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 40, !dbg !345 + %9873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 41, !dbg !345 + %9874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 42, !dbg !345 + %9875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 43, !dbg !345 + %9876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 44, !dbg !345 + %9877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 45, !dbg !345 + %9878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 46, !dbg !345 + %9879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 47, !dbg !345 + %9880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 48, !dbg !345 + %9881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 49, !dbg !345 + %9882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 50, !dbg !345 + %9883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 51, !dbg !345 + %9884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 52, !dbg !345 + %9885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 53, !dbg !345 + %9886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 54, !dbg !345 + %9887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 55, !dbg !345 + %9888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 56, !dbg !345 + %9889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 57, !dbg !345 + %9890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 58, !dbg !345 + %9891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 59, !dbg !345 + %9892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 60, !dbg !345 + %9893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 61, !dbg !345 + %9894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 62, !dbg !345 + %9895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9831, 63, !dbg !345 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !345 + %9896 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %8810, !dbg !331 + %9897 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5189, !dbg !331 + %9898 = load float, ptr addrspace(3) %9897, align 8, !dbg !331 + %9899 = getelementptr inbounds nuw i8, ptr addrspace(3) %9897, i32 4, !dbg !331 + %9900 = load float, ptr addrspace(3) %9899, align 4, !dbg !331 + %9901 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5195, !dbg !331 + %9902 = load float, ptr addrspace(3) %9901, align 8, !dbg !331 + %9903 = getelementptr inbounds nuw i8, ptr addrspace(3) %9901, i32 4, !dbg !331 + %9904 = load float, ptr addrspace(3) %9903, align 4, !dbg !331 + %9905 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5201, !dbg !331 + %9906 = load float, ptr addrspace(3) %9905, align 8, !dbg !331 + %9907 = getelementptr inbounds nuw i8, ptr addrspace(3) %9905, i32 4, !dbg !331 + %9908 = load float, ptr addrspace(3) %9907, align 4, !dbg !331 + %9909 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5207, !dbg !331 + %9910 = load float, ptr addrspace(3) %9909, align 8, !dbg !331 + %9911 = getelementptr inbounds nuw i8, ptr addrspace(3) %9909, i32 4, !dbg !331 + %9912 = load float, ptr addrspace(3) %9911, align 4, !dbg !331 + %9913 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5213, !dbg !331 + %9914 = load float, ptr addrspace(3) %9913, align 8, !dbg !331 + %9915 = getelementptr inbounds nuw i8, ptr addrspace(3) %9913, i32 4, !dbg !331 + %9916 = load float, ptr addrspace(3) %9915, align 4, !dbg !331 + %9917 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5219, !dbg !331 + %9918 = load float, ptr addrspace(3) %9917, align 8, !dbg !331 + %9919 = getelementptr inbounds nuw i8, ptr addrspace(3) %9917, i32 4, !dbg !331 + %9920 = load float, ptr addrspace(3) %9919, align 4, !dbg !331 + %9921 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5225, !dbg !331 + %9922 = load float, ptr addrspace(3) %9921, align 8, !dbg !331 + %9923 = getelementptr inbounds nuw i8, ptr addrspace(3) %9921, i32 4, !dbg !331 + %9924 = load float, ptr addrspace(3) %9923, align 4, !dbg !331 + %9925 = getelementptr inbounds nuw i8, ptr addrspace(3) %9896, i32 %5231, !dbg !331 + %9926 = load float, ptr addrspace(3) %9925, align 8, !dbg !331 + %9927 = getelementptr inbounds nuw i8, ptr addrspace(3) %9925, i32 4, !dbg !331 + %9928 = load float, ptr addrspace(3) %9927, align 4, !dbg !331 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !346 + %9929 = add i32 %8878, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %9930 = lshr exact i32 %9929, 4, !dbg !346 + %9931 = and i32 %9930, 16383, !dbg !346 + %9932 = zext nneg i32 %9931 to i64, !dbg !346 + %9933 = or disjoint i64 %9932, 4611686293372403712, !dbg !346 + %9934 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %9933, i64 %9620) #3, !dbg !346 + %9935 = add i32 %8890, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %9936 = lshr exact i32 %9935, 4, !dbg !346 + %9937 = and i32 %9936, 16383, !dbg !346 + %9938 = zext nneg i32 %9937 to i64, !dbg !346 + %9939 = or disjoint i64 %9938, 4611686293372403712, !dbg !346 + %9940 = add i32 %9616, 32, !dbg !346 + %9941 = lshr exact i32 %9940, 4, !dbg !346 + %9942 = and i32 %9941, 16383, !dbg !346 + %9943 = zext nneg i32 %9942 to i64, !dbg !346 + %9944 = or disjoint i64 %9943, 4611686293338849280, !dbg !346 + %9945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 0, !dbg !346 + %9946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 1, !dbg !346 + %9947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 2, !dbg !346 + %9948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 3, !dbg !346 + %9949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 4, !dbg !346 + %9950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 5, !dbg !346 + %9951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 6, !dbg !346 + %9952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 7, !dbg !346 + %9953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 8, !dbg !346 + %9954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 9, !dbg !346 + %9955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 10, !dbg !346 + %9956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 11, !dbg !346 + %9957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 12, !dbg !346 + %9958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 13, !dbg !346 + %9959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 14, !dbg !346 + %9960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 15, !dbg !346 + %9961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 16, !dbg !346 + %9962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 17, !dbg !346 + %9963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 18, !dbg !346 + %9964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 19, !dbg !346 + %9965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 20, !dbg !346 + %9966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 21, !dbg !346 + %9967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 22, !dbg !346 + %9968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 23, !dbg !346 + %9969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 24, !dbg !346 + %9970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 25, !dbg !346 + %9971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 26, !dbg !346 + %9972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 27, !dbg !346 + %9973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 28, !dbg !346 + %9974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 29, !dbg !346 + %9975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 30, !dbg !346 + %9976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9934, 31, !dbg !346 + %9977 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9945, float %9946, float %9947, float %9948, float %9949, float %9950, float %9951, float %9952, float %9953, float %9954, float %9955, float %9956, float %9957, float %9958, float %9959, float %9960, float %9961, float %9962, float %9963, float %9964, float %9965, float %9966, float %9967, float %9968, float %9969, float %9970, float %9971, float %9972, float %9973, float %9974, float %9975, float %9976, i64 %9939, i64 %9944, i1 true) #3, !dbg !346 + %9978 = add i32 %8934, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %9979 = lshr exact i32 %9978, 4, !dbg !346 + %9980 = and i32 %9979, 16383, !dbg !346 + %9981 = zext nneg i32 %9980 to i64, !dbg !346 + %9982 = or disjoint i64 %9981, 4611686293372403712, !dbg !346 + %9983 = add i32 %9616, 64, !dbg !346 + %9984 = lshr exact i32 %9983, 4, !dbg !346 + %9985 = and i32 %9984, 16383, !dbg !346 + %9986 = zext nneg i32 %9985 to i64, !dbg !346 + %9987 = or disjoint i64 %9986, 4611686293338849280, !dbg !346 + %9988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 0, !dbg !346 + %9989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 1, !dbg !346 + %9990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 2, !dbg !346 + %9991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 3, !dbg !346 + %9992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 4, !dbg !346 + %9993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 5, !dbg !346 + %9994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 6, !dbg !346 + %9995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 7, !dbg !346 + %9996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 8, !dbg !346 + %9997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 9, !dbg !346 + %9998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 10, !dbg !346 + %9999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 11, !dbg !346 + %10000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 12, !dbg !346 + %10001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 13, !dbg !346 + %10002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 14, !dbg !346 + %10003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 15, !dbg !346 + %10004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 16, !dbg !346 + %10005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 17, !dbg !346 + %10006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 18, !dbg !346 + %10007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 19, !dbg !346 + %10008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 20, !dbg !346 + %10009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 21, !dbg !346 + %10010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 22, !dbg !346 + %10011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 23, !dbg !346 + %10012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 24, !dbg !346 + %10013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 25, !dbg !346 + %10014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 26, !dbg !346 + %10015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 27, !dbg !346 + %10016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 28, !dbg !346 + %10017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 29, !dbg !346 + %10018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 30, !dbg !346 + %10019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9977, 31, !dbg !346 + %10020 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9988, float %9989, float %9990, float %9991, float %9992, float %9993, float %9994, float %9995, float %9996, float %9997, float %9998, float %9999, float %10000, float %10001, float %10002, float %10003, float %10004, float %10005, float %10006, float %10007, float %10008, float %10009, float %10010, float %10011, float %10012, float %10013, float %10014, float %10015, float %10016, float %10017, float %10018, float %10019, i64 %9982, i64 %9987, i1 true) #3, !dbg !346 + %10021 = add i32 %8978, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %10022 = lshr exact i32 %10021, 4, !dbg !346 + %10023 = and i32 %10022, 16383, !dbg !346 + %10024 = zext nneg i32 %10023 to i64, !dbg !346 + %10025 = or disjoint i64 %10024, 4611686293372403712, !dbg !346 + %10026 = add i32 %9616, 96, !dbg !346 + %10027 = lshr exact i32 %10026, 4, !dbg !346 + %10028 = and i32 %10027, 16383, !dbg !346 + %10029 = zext nneg i32 %10028 to i64, !dbg !346 + %10030 = or disjoint i64 %10029, 4611686293338849280, !dbg !346 + %10031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 0, !dbg !346 + %10032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 1, !dbg !346 + %10033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 2, !dbg !346 + %10034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 3, !dbg !346 + %10035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 4, !dbg !346 + %10036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 5, !dbg !346 + %10037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 6, !dbg !346 + %10038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 7, !dbg !346 + %10039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 8, !dbg !346 + %10040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 9, !dbg !346 + %10041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 10, !dbg !346 + %10042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 11, !dbg !346 + %10043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 12, !dbg !346 + %10044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 13, !dbg !346 + %10045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 14, !dbg !346 + %10046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 15, !dbg !346 + %10047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 16, !dbg !346 + %10048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 17, !dbg !346 + %10049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 18, !dbg !346 + %10050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 19, !dbg !346 + %10051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 20, !dbg !346 + %10052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 21, !dbg !346 + %10053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 22, !dbg !346 + %10054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 23, !dbg !346 + %10055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 24, !dbg !346 + %10056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 25, !dbg !346 + %10057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 26, !dbg !346 + %10058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 27, !dbg !346 + %10059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 28, !dbg !346 + %10060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 29, !dbg !346 + %10061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 30, !dbg !346 + %10062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10020, 31, !dbg !346 + %10063 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10031, float %10032, float %10033, float %10034, float %10035, float %10036, float %10037, float %10038, float %10039, float %10040, float %10041, float %10042, float %10043, float %10044, float %10045, float %10046, float %10047, float %10048, float %10049, float %10050, float %10051, float %10052, float %10053, float %10054, float %10055, float %10056, float %10057, float %10058, float %10059, float %10060, float %10061, float %10062, i64 %10025, i64 %10030, i1 true) #3, !dbg !346 + %10064 = add i32 %9022, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %10065 = lshr exact i32 %10064, 4, !dbg !346 + %10066 = and i32 %10065, 16383, !dbg !346 + %10067 = zext nneg i32 %10066 to i64, !dbg !346 + %10068 = or disjoint i64 %10067, 4611686293372403712, !dbg !346 + %10069 = add i32 %9616, 8192, !dbg !346 + %10070 = lshr exact i32 %10069, 4, !dbg !346 + %10071 = and i32 %10070, 16383, !dbg !346 + %10072 = zext nneg i32 %10071 to i64, !dbg !346 + %10073 = or disjoint i64 %10072, 4611686293338849280, !dbg !346 + %10074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 0, !dbg !346 + %10075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 1, !dbg !346 + %10076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 2, !dbg !346 + %10077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 3, !dbg !346 + %10078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 4, !dbg !346 + %10079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 5, !dbg !346 + %10080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 6, !dbg !346 + %10081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 7, !dbg !346 + %10082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 8, !dbg !346 + %10083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 9, !dbg !346 + %10084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 10, !dbg !346 + %10085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 11, !dbg !346 + %10086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 12, !dbg !346 + %10087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 13, !dbg !346 + %10088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 14, !dbg !346 + %10089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 15, !dbg !346 + %10090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 16, !dbg !346 + %10091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 17, !dbg !346 + %10092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 18, !dbg !346 + %10093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 19, !dbg !346 + %10094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 20, !dbg !346 + %10095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 21, !dbg !346 + %10096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 22, !dbg !346 + %10097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 23, !dbg !346 + %10098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 24, !dbg !346 + %10099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 25, !dbg !346 + %10100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 26, !dbg !346 + %10101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 27, !dbg !346 + %10102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 28, !dbg !346 + %10103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 29, !dbg !346 + %10104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 30, !dbg !346 + %10105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10063, 31, !dbg !346 + %10106 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10074, float %10075, float %10076, float %10077, float %10078, float %10079, float %10080, float %10081, float %10082, float %10083, float %10084, float %10085, float %10086, float %10087, float %10088, float %10089, float %10090, float %10091, float %10092, float %10093, float %10094, float %10095, float %10096, float %10097, float %10098, float %10099, float %10100, float %10101, float %10102, float %10103, float %10104, float %10105, i64 %10068, i64 %10073, i1 true) #3, !dbg !346 + %10107 = add i32 %9066, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %10108 = lshr exact i32 %10107, 4, !dbg !346 + %10109 = and i32 %10108, 16383, !dbg !346 + %10110 = zext nneg i32 %10109 to i64, !dbg !346 + %10111 = or disjoint i64 %10110, 4611686293372403712, !dbg !346 + %10112 = add i32 %9616, 8224, !dbg !346 + %10113 = lshr exact i32 %10112, 4, !dbg !346 + %10114 = and i32 %10113, 16383, !dbg !346 + %10115 = zext nneg i32 %10114 to i64, !dbg !346 + %10116 = or disjoint i64 %10115, 4611686293338849280, !dbg !346 + %10117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 0, !dbg !346 + %10118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 1, !dbg !346 + %10119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 2, !dbg !346 + %10120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 3, !dbg !346 + %10121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 4, !dbg !346 + %10122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 5, !dbg !346 + %10123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 6, !dbg !346 + %10124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 7, !dbg !346 + %10125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 8, !dbg !346 + %10126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 9, !dbg !346 + %10127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 10, !dbg !346 + %10128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 11, !dbg !346 + %10129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 12, !dbg !346 + %10130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 13, !dbg !346 + %10131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 14, !dbg !346 + %10132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 15, !dbg !346 + %10133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 16, !dbg !346 + %10134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 17, !dbg !346 + %10135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 18, !dbg !346 + %10136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 19, !dbg !346 + %10137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 20, !dbg !346 + %10138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 21, !dbg !346 + %10139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 22, !dbg !346 + %10140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 23, !dbg !346 + %10141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 24, !dbg !346 + %10142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 25, !dbg !346 + %10143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 26, !dbg !346 + %10144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 27, !dbg !346 + %10145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 28, !dbg !346 + %10146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 29, !dbg !346 + %10147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 30, !dbg !346 + %10148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10106, 31, !dbg !346 + %10149 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10117, float %10118, float %10119, float %10120, float %10121, float %10122, float %10123, float %10124, float %10125, float %10126, float %10127, float %10128, float %10129, float %10130, float %10131, float %10132, float %10133, float %10134, float %10135, float %10136, float %10137, float %10138, float %10139, float %10140, float %10141, float %10142, float %10143, float %10144, float %10145, float %10146, float %10147, float %10148, i64 %10111, i64 %10116, i1 true) #3, !dbg !346 + %10150 = add i32 %9110, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %10151 = lshr exact i32 %10150, 4, !dbg !346 + %10152 = and i32 %10151, 16383, !dbg !346 + %10153 = zext nneg i32 %10152 to i64, !dbg !346 + %10154 = or disjoint i64 %10153, 4611686293372403712, !dbg !346 + %10155 = add i32 %9616, 8256, !dbg !346 + %10156 = lshr exact i32 %10155, 4, !dbg !346 + %10157 = and i32 %10156, 16383, !dbg !346 + %10158 = zext nneg i32 %10157 to i64, !dbg !346 + %10159 = or disjoint i64 %10158, 4611686293338849280, !dbg !346 + %10160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 0, !dbg !346 + %10161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 1, !dbg !346 + %10162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 2, !dbg !346 + %10163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 3, !dbg !346 + %10164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 4, !dbg !346 + %10165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 5, !dbg !346 + %10166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 6, !dbg !346 + %10167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 7, !dbg !346 + %10168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 8, !dbg !346 + %10169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 9, !dbg !346 + %10170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 10, !dbg !346 + %10171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 11, !dbg !346 + %10172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 12, !dbg !346 + %10173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 13, !dbg !346 + %10174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 14, !dbg !346 + %10175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 15, !dbg !346 + %10176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 16, !dbg !346 + %10177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 17, !dbg !346 + %10178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 18, !dbg !346 + %10179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 19, !dbg !346 + %10180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 20, !dbg !346 + %10181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 21, !dbg !346 + %10182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 22, !dbg !346 + %10183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 23, !dbg !346 + %10184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 24, !dbg !346 + %10185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 25, !dbg !346 + %10186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 26, !dbg !346 + %10187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 27, !dbg !346 + %10188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 28, !dbg !346 + %10189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 29, !dbg !346 + %10190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 30, !dbg !346 + %10191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10149, 31, !dbg !346 + %10192 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10160, float %10161, float %10162, float %10163, float %10164, float %10165, float %10166, float %10167, float %10168, float %10169, float %10170, float %10171, float %10172, float %10173, float %10174, float %10175, float %10176, float %10177, float %10178, float %10179, float %10180, float %10181, float %10182, float %10183, float %10184, float %10185, float %10186, float %10187, float %10188, float %10189, float %10190, float %10191, i64 %10154, i64 %10159, i1 true) #3, !dbg !346 + %10193 = add i32 %9154, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !346 + %10194 = lshr exact i32 %10193, 4, !dbg !346 + %10195 = and i32 %10194, 16383, !dbg !346 + %10196 = zext nneg i32 %10195 to i64, !dbg !346 + %10197 = or disjoint i64 %10196, 4611686293372403712, !dbg !346 + %10198 = add i32 %9616, 8288, !dbg !346 + %10199 = lshr exact i32 %10198, 4, !dbg !346 + %10200 = and i32 %10199, 16383, !dbg !346 + %10201 = zext nneg i32 %10200 to i64, !dbg !346 + %10202 = or disjoint i64 %10201, 4611686293338849280, !dbg !346 + %10203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 0, !dbg !346 + %10204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 1, !dbg !346 + %10205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 2, !dbg !346 + %10206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 3, !dbg !346 + %10207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 4, !dbg !346 + %10208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 5, !dbg !346 + %10209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 6, !dbg !346 + %10210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 7, !dbg !346 + %10211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 8, !dbg !346 + %10212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 9, !dbg !346 + %10213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 10, !dbg !346 + %10214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 11, !dbg !346 + %10215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 12, !dbg !346 + %10216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 13, !dbg !346 + %10217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 14, !dbg !346 + %10218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 15, !dbg !346 + %10219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 16, !dbg !346 + %10220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 17, !dbg !346 + %10221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 18, !dbg !346 + %10222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 19, !dbg !346 + %10223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 20, !dbg !346 + %10224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 21, !dbg !346 + %10225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 22, !dbg !346 + %10226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 23, !dbg !346 + %10227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 24, !dbg !346 + %10228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 25, !dbg !346 + %10229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 26, !dbg !346 + %10230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 27, !dbg !346 + %10231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 28, !dbg !346 + %10232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 29, !dbg !346 + %10233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 30, !dbg !346 + %10234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10192, 31, !dbg !346 + %10235 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10203, float %10204, float %10205, float %10206, float %10207, float %10208, float %10209, float %10210, float %10211, float %10212, float %10213, float %10214, float %10215, float %10216, float %10217, float %10218, float %10219, float %10220, float %10221, float %10222, float %10223, float %10224, float %10225, float %10226, float %10227, float %10228, float %10229, float %10230, float %10231, float %10232, float %10233, float %10234, i64 %10197, i64 %10202, i1 true) #3, !dbg !346 + %10236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 0, !dbg !346 + %10237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 1, !dbg !346 + %10238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 2, !dbg !346 + %10239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 3, !dbg !346 + %10240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 4, !dbg !346 + %10241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 5, !dbg !346 + %10242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 6, !dbg !346 + %10243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 7, !dbg !346 + %10244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 8, !dbg !346 + %10245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 9, !dbg !346 + %10246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 10, !dbg !346 + %10247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 11, !dbg !346 + %10248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 12, !dbg !346 + %10249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 13, !dbg !346 + %10250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 14, !dbg !346 + %10251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 15, !dbg !346 + %10252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 16, !dbg !346 + %10253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 17, !dbg !346 + %10254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 18, !dbg !346 + %10255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 19, !dbg !346 + %10256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 20, !dbg !346 + %10257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 21, !dbg !346 + %10258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 22, !dbg !346 + %10259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 23, !dbg !346 + %10260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 24, !dbg !346 + %10261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 25, !dbg !346 + %10262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 26, !dbg !346 + %10263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 27, !dbg !346 + %10264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 28, !dbg !346 + %10265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 29, !dbg !346 + %10266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 30, !dbg !346 + %10267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10235, 31, !dbg !346 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !346 + %10268 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %10236, float %10237, float %10238, float %10239, float %10240, float %10241, float %10242, float %10243, float %10244, float %10245, float %10246, float %10247, float %10248, float %10249, float %10250, float %10251, float %10252, float %10253, float %10254, float %10255, float %10256, float %10257, float %10258, float %10259, float %10260, float %10261, float %10262, float %10263, float %10264, float %10265, float %10266, float %10267, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %9551, i32 0, i32 0) #3, !dbg !346 + %10269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 0, !dbg !346 + %10270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 1, !dbg !346 + %10271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 2, !dbg !346 + %10272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 3, !dbg !346 + %10273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 4, !dbg !346 + %10274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 5, !dbg !346 + %10275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 6, !dbg !346 + %10276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 7, !dbg !346 + %10277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 8, !dbg !346 + %10278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 9, !dbg !346 + %10279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 10, !dbg !346 + %10280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 11, !dbg !346 + %10281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 12, !dbg !346 + %10282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 13, !dbg !346 + %10283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 14, !dbg !346 + %10284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 15, !dbg !346 + %10285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 16, !dbg !346 + %10286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 17, !dbg !346 + %10287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 18, !dbg !346 + %10288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 19, !dbg !346 + %10289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 20, !dbg !346 + %10290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 21, !dbg !346 + %10291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 22, !dbg !346 + %10292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 23, !dbg !346 + %10293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 24, !dbg !346 + %10294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 25, !dbg !346 + %10295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 26, !dbg !346 + %10296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 27, !dbg !346 + %10297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 28, !dbg !346 + %10298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 29, !dbg !346 + %10299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 30, !dbg !346 + %10300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10268, 31, !dbg !346 + %10301 = fsub float %10269, %9898, !dbg !347 + %10302 = fsub float %10270, %9900, !dbg !347 + %10303 = fsub float %10271, %9898, !dbg !347 + %10304 = fsub float %10272, %9900, !dbg !347 + %10305 = fsub float %10273, %9902, !dbg !347 + %10306 = fsub float %10274, %9904, !dbg !347 + %10307 = fsub float %10275, %9902, !dbg !347 + %10308 = fsub float %10276, %9904, !dbg !347 + %10309 = fsub float %10277, %9906, !dbg !347 + %10310 = fsub float %10278, %9908, !dbg !347 + %10311 = fsub float %10279, %9906, !dbg !347 + %10312 = fsub float %10280, %9908, !dbg !347 + %10313 = fsub float %10281, %9910, !dbg !347 + %10314 = fsub float %10282, %9912, !dbg !347 + %10315 = fsub float %10283, %9910, !dbg !347 + %10316 = fsub float %10284, %9912, !dbg !347 + %10317 = fsub float %10285, %9914, !dbg !347 + %10318 = fsub float %10286, %9916, !dbg !347 + %10319 = fsub float %10287, %9914, !dbg !347 + %10320 = fsub float %10288, %9916, !dbg !347 + %10321 = fsub float %10289, %9918, !dbg !347 + %10322 = fsub float %10290, %9920, !dbg !347 + %10323 = fsub float %10291, %9918, !dbg !347 + %10324 = fsub float %10292, %9920, !dbg !347 + %10325 = fsub float %10293, %9922, !dbg !347 + %10326 = fsub float %10294, %9924, !dbg !347 + %10327 = fsub float %10295, %9922, !dbg !347 + %10328 = fsub float %10296, %9924, !dbg !347 + %10329 = fsub float %10297, %9926, !dbg !347 + %10330 = fsub float %10298, %9928, !dbg !347 + %10331 = fsub float %10299, %9926, !dbg !347 + %10332 = fsub float %10300, %9928, !dbg !347 + %10333 = fmul float %.0.i, %10301, !dbg !348 + %10334 = fmul float %.0.i1230, %10302, !dbg !348 + %10335 = fmul float %.0.i1233, %10303, !dbg !348 + %10336 = fmul float %.0.i1236, %10304, !dbg !348 + %10337 = fmul float %.0.i1239, %10305, !dbg !348 + %10338 = fmul float %.0.i1242, %10306, !dbg !348 + %10339 = fmul float %.0.i1245, %10307, !dbg !348 + %10340 = fmul float %.0.i1248, %10308, !dbg !348 + %10341 = fmul float %.0.i1251, %10309, !dbg !348 + %10342 = fmul float %.0.i1254, %10310, !dbg !348 + %10343 = fmul float %.0.i1257, %10311, !dbg !348 + %10344 = fmul float %.0.i1260, %10312, !dbg !348 + %10345 = fmul float %.0.i1263, %10313, !dbg !348 + %10346 = fmul float %.0.i1266, %10314, !dbg !348 + %10347 = fmul float %.0.i1269, %10315, !dbg !348 + %10348 = fmul float %.0.i1272, %10316, !dbg !348 + %10349 = fmul float %.0.i1275, %10317, !dbg !348 + %10350 = fmul float %.0.i1278, %10318, !dbg !348 + %10351 = fmul float %.0.i1281, %10319, !dbg !348 + %10352 = fmul float %.0.i1284, %10320, !dbg !348 + %10353 = fmul float %.0.i1287, %10321, !dbg !348 + %10354 = fmul float %.0.i1290, %10322, !dbg !348 + %10355 = fmul float %.0.i1293, %10323, !dbg !348 + %10356 = fmul float %.0.i1296, %10324, !dbg !348 + %10357 = fmul float %.0.i1299, %10325, !dbg !348 + %10358 = fmul float %.0.i1302, %10326, !dbg !348 + %10359 = fmul float %.0.i1305, %10327, !dbg !348 + %10360 = fmul float %.0.i1308, %10328, !dbg !348 + %10361 = fmul float %.0.i1311, %10329, !dbg !348 + %10362 = fmul float %.0.i1314, %10330, !dbg !348 + %10363 = fmul float %.0.i1317, %10331, !dbg !348 + %10364 = fmul float %.0.i1320, %10332, !dbg !348 + %10365 = fptrunc float %10333 to bfloat, !dbg !349 + %10366 = select i1 %8792, bfloat %10365, bfloat 0xR0000, !dbg !350 + %10367 = fptrunc float %10334 to bfloat, !dbg !349 + %10368 = select i1 %8793, bfloat %10367, bfloat 0xR0000, !dbg !350 + %10369 = fptrunc float %10335 to bfloat, !dbg !349 + %10370 = select i1 %8792, bfloat %10369, bfloat 0xR0000, !dbg !350 + %10371 = fptrunc float %10336 to bfloat, !dbg !349 + %10372 = select i1 %8793, bfloat %10371, bfloat 0xR0000, !dbg !350 + %10373 = fptrunc float %10337 to bfloat, !dbg !349 + %10374 = select i1 %8794, bfloat %10373, bfloat 0xR0000, !dbg !350 + %10375 = fptrunc float %10338 to bfloat, !dbg !349 + %10376 = select i1 %8795, bfloat %10375, bfloat 0xR0000, !dbg !350 + %10377 = fptrunc float %10339 to bfloat, !dbg !349 + %10378 = select i1 %8794, bfloat %10377, bfloat 0xR0000, !dbg !350 + %10379 = fptrunc float %10340 to bfloat, !dbg !349 + %10380 = select i1 %8795, bfloat %10379, bfloat 0xR0000, !dbg !350 + %10381 = fptrunc float %10341 to bfloat, !dbg !349 + %10382 = select i1 %8796, bfloat %10381, bfloat 0xR0000, !dbg !350 + %10383 = fptrunc float %10342 to bfloat, !dbg !349 + %10384 = select i1 %8797, bfloat %10383, bfloat 0xR0000, !dbg !350 + %10385 = fptrunc float %10343 to bfloat, !dbg !349 + %10386 = select i1 %8796, bfloat %10385, bfloat 0xR0000, !dbg !350 + %10387 = fptrunc float %10344 to bfloat, !dbg !349 + %10388 = select i1 %8797, bfloat %10387, bfloat 0xR0000, !dbg !350 + %10389 = fptrunc float %10345 to bfloat, !dbg !349 + %10390 = select i1 %8798, bfloat %10389, bfloat 0xR0000, !dbg !350 + %10391 = fptrunc float %10346 to bfloat, !dbg !349 + %10392 = select i1 %8799, bfloat %10391, bfloat 0xR0000, !dbg !350 + %10393 = fptrunc float %10347 to bfloat, !dbg !349 + %10394 = select i1 %8798, bfloat %10393, bfloat 0xR0000, !dbg !350 + %10395 = fptrunc float %10348 to bfloat, !dbg !349 + %10396 = select i1 %8799, bfloat %10395, bfloat 0xR0000, !dbg !350 + %10397 = fptrunc float %10349 to bfloat, !dbg !349 + %10398 = select i1 %8800, bfloat %10397, bfloat 0xR0000, !dbg !350 + %10399 = fptrunc float %10350 to bfloat, !dbg !349 + %10400 = select i1 %8801, bfloat %10399, bfloat 0xR0000, !dbg !350 + %10401 = fptrunc float %10351 to bfloat, !dbg !349 + %10402 = select i1 %8800, bfloat %10401, bfloat 0xR0000, !dbg !350 + %10403 = fptrunc float %10352 to bfloat, !dbg !349 + %10404 = select i1 %8801, bfloat %10403, bfloat 0xR0000, !dbg !350 + %10405 = fptrunc float %10353 to bfloat, !dbg !349 + %10406 = select i1 %8802, bfloat %10405, bfloat 0xR0000, !dbg !350 + %10407 = fptrunc float %10354 to bfloat, !dbg !349 + %10408 = select i1 %8803, bfloat %10407, bfloat 0xR0000, !dbg !350 + %10409 = fptrunc float %10355 to bfloat, !dbg !349 + %10410 = select i1 %8802, bfloat %10409, bfloat 0xR0000, !dbg !350 + %10411 = fptrunc float %10356 to bfloat, !dbg !349 + %10412 = select i1 %8803, bfloat %10411, bfloat 0xR0000, !dbg !350 + %10413 = fptrunc float %10357 to bfloat, !dbg !349 + %10414 = select i1 %8804, bfloat %10413, bfloat 0xR0000, !dbg !350 + %10415 = fptrunc float %10358 to bfloat, !dbg !349 + %10416 = select i1 %8805, bfloat %10415, bfloat 0xR0000, !dbg !350 + %10417 = fptrunc float %10359 to bfloat, !dbg !349 + %10418 = select i1 %8804, bfloat %10417, bfloat 0xR0000, !dbg !350 + %10419 = fptrunc float %10360 to bfloat, !dbg !349 + %10420 = select i1 %8805, bfloat %10419, bfloat 0xR0000, !dbg !350 + %10421 = fptrunc float %10361 to bfloat, !dbg !349 + %10422 = select i1 %8806, bfloat %10421, bfloat 0xR0000, !dbg !350 + %10423 = fptrunc float %10362 to bfloat, !dbg !349 + %10424 = select i1 %8807, bfloat %10423, bfloat 0xR0000, !dbg !350 + %10425 = fptrunc float %10363 to bfloat, !dbg !349 + %10426 = select i1 %8806, bfloat %10425, bfloat 0xR0000, !dbg !350 + %10427 = fptrunc float %10364 to bfloat, !dbg !349 + %10428 = select i1 %8807, bfloat %10427, bfloat 0xR0000, !dbg !350 + %10429 = insertelement <2 x bfloat> poison, bfloat %10366, i64 0, !dbg !351 + %10430 = insertelement <2 x bfloat> %10429, bfloat %10368, i64 1, !dbg !351 + %10431 = bitcast <2 x bfloat> %10430 to i32, !dbg !351 + %10432 = insertelement <2 x bfloat> poison, bfloat %10370, i64 0, !dbg !351 + %10433 = insertelement <2 x bfloat> %10432, bfloat %10372, i64 1, !dbg !351 + %10434 = bitcast <2 x bfloat> %10433 to i32, !dbg !351 + %10435 = insertelement <2 x bfloat> poison, bfloat %10374, i64 0, !dbg !351 + %10436 = insertelement <2 x bfloat> %10435, bfloat %10376, i64 1, !dbg !351 + %10437 = bitcast <2 x bfloat> %10436 to i32, !dbg !351 + %10438 = insertelement <2 x bfloat> poison, bfloat %10378, i64 0, !dbg !351 + %10439 = insertelement <2 x bfloat> %10438, bfloat %10380, i64 1, !dbg !351 + %10440 = bitcast <2 x bfloat> %10439 to i32, !dbg !351 + %10441 = insertelement <2 x bfloat> poison, bfloat %10382, i64 0, !dbg !351 + %10442 = insertelement <2 x bfloat> %10441, bfloat %10384, i64 1, !dbg !351 + %10443 = bitcast <2 x bfloat> %10442 to i32, !dbg !351 + %10444 = insertelement <2 x bfloat> poison, bfloat %10386, i64 0, !dbg !351 + %10445 = insertelement <2 x bfloat> %10444, bfloat %10388, i64 1, !dbg !351 + %10446 = bitcast <2 x bfloat> %10445 to i32, !dbg !351 + %10447 = insertelement <2 x bfloat> poison, bfloat %10390, i64 0, !dbg !351 + %10448 = insertelement <2 x bfloat> %10447, bfloat %10392, i64 1, !dbg !351 + %10449 = bitcast <2 x bfloat> %10448 to i32, !dbg !351 + %10450 = insertelement <2 x bfloat> poison, bfloat %10394, i64 0, !dbg !351 + %10451 = insertelement <2 x bfloat> %10450, bfloat %10396, i64 1, !dbg !351 + %10452 = bitcast <2 x bfloat> %10451 to i32, !dbg !351 + %10453 = insertelement <2 x bfloat> poison, bfloat %10398, i64 0, !dbg !351 + %10454 = insertelement <2 x bfloat> %10453, bfloat %10400, i64 1, !dbg !351 + %10455 = bitcast <2 x bfloat> %10454 to i32, !dbg !351 + %10456 = insertelement <2 x bfloat> poison, bfloat %10402, i64 0, !dbg !351 + %10457 = insertelement <2 x bfloat> %10456, bfloat %10404, i64 1, !dbg !351 + %10458 = bitcast <2 x bfloat> %10457 to i32, !dbg !351 + %10459 = insertelement <2 x bfloat> poison, bfloat %10406, i64 0, !dbg !351 + %10460 = insertelement <2 x bfloat> %10459, bfloat %10408, i64 1, !dbg !351 + %10461 = bitcast <2 x bfloat> %10460 to i32, !dbg !351 + %10462 = insertelement <2 x bfloat> poison, bfloat %10410, i64 0, !dbg !351 + %10463 = insertelement <2 x bfloat> %10462, bfloat %10412, i64 1, !dbg !351 + %10464 = bitcast <2 x bfloat> %10463 to i32, !dbg !351 + %10465 = insertelement <2 x bfloat> poison, bfloat %10414, i64 0, !dbg !351 + %10466 = insertelement <2 x bfloat> %10465, bfloat %10416, i64 1, !dbg !351 + %10467 = bitcast <2 x bfloat> %10466 to i32, !dbg !351 + %10468 = insertelement <2 x bfloat> poison, bfloat %10418, i64 0, !dbg !351 + %10469 = insertelement <2 x bfloat> %10468, bfloat %10420, i64 1, !dbg !351 + %10470 = bitcast <2 x bfloat> %10469 to i32, !dbg !351 + %10471 = insertelement <2 x bfloat> poison, bfloat %10422, i64 0, !dbg !351 + %10472 = insertelement <2 x bfloat> %10471, bfloat %10424, i64 1, !dbg !351 + %10473 = bitcast <2 x bfloat> %10472 to i32, !dbg !351 + %10474 = insertelement <2 x bfloat> poison, bfloat %10426, i64 0, !dbg !351 + %10475 = insertelement <2 x bfloat> %10474, bfloat %10428, i64 1, !dbg !351 + %10476 = bitcast <2 x bfloat> %10475 to i32, !dbg !351 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !351 + %10477 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn3471704, float %.pn3451705, float %.pn3431706, float %.pn3411707, float %.pn3391708, float %.pn3371709, float %.pn3351710, float %.pn3331711, float %.pn3311712, float %.pn3291713, float %.pn3271714, float %.pn3251715, float %.pn3231716, float %.pn3211717, float %.pn3191718, float %.pn3171719, float %.pn3151720, float %.pn3131721, float %.pn3111722, float %.pn3091723, float %.pn3071724, float %.pn3051725, float %.pn3031726, float %.pn3011727, float %.pn2991728, float %.pn2971729, float %.pn2951730, float %.pn2931731, float %.pn2911732, float %.pn2891733, float %.pn2871734, float %.pn2851735, float %.pn2831736, float %.pn2811737, float %.pn2791738, float %.pn2771739, float %.pn2751740, float %.pn2731741, float %.pn2711742, float %.pn2691743, float %.pn2671744, float %.pn2651745, float %.pn2631746, float %.pn2611747, float %.pn2591748, float %.pn2571749, float %.pn2551750, float %.pn2531751, float %.pn2511752, float %.pn2491753, float %.pn2471754, float %.pn2451755, float %.pn2431756, float %.pn2411757, float %.pn2391758, float %.pn2371759, float %.pn2351760, float %.pn2331761, float %.pn2311762, float %.pn2291763, float %.pn2271764, float %.pn2251765, float %.pn2231766, float %.pn2211767, i32 %10431, i32 %10434, i32 %10437, i32 %10440, i64 %8888, i1 true) #3, !dbg !351 + %10478 = add i32 %8884, 2048, !dbg !351 + %10479 = lshr exact i32 %10478, 4, !dbg !351 + %10480 = and i32 %10479, 16383, !dbg !351 + %10481 = zext nneg i32 %10480 to i64, !dbg !351 + %10482 = or disjoint i64 %10481, 4611686293338849280, !dbg !351 + %10483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 0, !dbg !351 + %10484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 1, !dbg !351 + %10485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 2, !dbg !351 + %10486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 3, !dbg !351 + %10487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 4, !dbg !351 + %10488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 5, !dbg !351 + %10489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 6, !dbg !351 + %10490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 7, !dbg !351 + %10491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 8, !dbg !351 + %10492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 9, !dbg !351 + %10493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 10, !dbg !351 + %10494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 11, !dbg !351 + %10495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 12, !dbg !351 + %10496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 13, !dbg !351 + %10497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 14, !dbg !351 + %10498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 15, !dbg !351 + %10499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 16, !dbg !351 + %10500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 17, !dbg !351 + %10501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 18, !dbg !351 + %10502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 19, !dbg !351 + %10503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 20, !dbg !351 + %10504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 21, !dbg !351 + %10505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 22, !dbg !351 + %10506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 23, !dbg !351 + %10507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 24, !dbg !351 + %10508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 25, !dbg !351 + %10509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 26, !dbg !351 + %10510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 27, !dbg !351 + %10511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 28, !dbg !351 + %10512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 29, !dbg !351 + %10513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 30, !dbg !351 + %10514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 31, !dbg !351 + %10515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 32, !dbg !351 + %10516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 33, !dbg !351 + %10517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 34, !dbg !351 + %10518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 35, !dbg !351 + %10519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 36, !dbg !351 + %10520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 37, !dbg !351 + %10521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 38, !dbg !351 + %10522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 39, !dbg !351 + %10523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 40, !dbg !351 + %10524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 41, !dbg !351 + %10525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 42, !dbg !351 + %10526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 43, !dbg !351 + %10527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 44, !dbg !351 + %10528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 45, !dbg !351 + %10529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 46, !dbg !351 + %10530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 47, !dbg !351 + %10531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 48, !dbg !351 + %10532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 49, !dbg !351 + %10533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 50, !dbg !351 + %10534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 51, !dbg !351 + %10535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 52, !dbg !351 + %10536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 53, !dbg !351 + %10537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 54, !dbg !351 + %10538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 55, !dbg !351 + %10539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 56, !dbg !351 + %10540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 57, !dbg !351 + %10541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 58, !dbg !351 + %10542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 59, !dbg !351 + %10543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 60, !dbg !351 + %10544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 61, !dbg !351 + %10545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 62, !dbg !351 + %10546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10477, 63, !dbg !351 + %10547 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10483, float %10484, float %10485, float %10486, float %10487, float %10488, float %10489, float %10490, float %10491, float %10492, float %10493, float %10494, float %10495, float %10496, float %10497, float %10498, float %10499, float %10500, float %10501, float %10502, float %10503, float %10504, float %10505, float %10506, float %10507, float %10508, float %10509, float %10510, float %10511, float %10512, float %10513, float %10514, float %10515, float %10516, float %10517, float %10518, float %10519, float %10520, float %10521, float %10522, float %10523, float %10524, float %10525, float %10526, float %10527, float %10528, float %10529, float %10530, float %10531, float %10532, float %10533, float %10534, float %10535, float %10536, float %10537, float %10538, float %10539, float %10540, float %10541, float %10542, float %10543, float %10544, float %10545, float %10546, i32 %10443, i32 %10446, i32 %10449, i32 %10452, i64 %10482, i1 true) #3, !dbg !351 + %10548 = add i32 %8884, 4096, !dbg !351 + %10549 = lshr exact i32 %10548, 4, !dbg !351 + %10550 = and i32 %10549, 16383, !dbg !351 + %10551 = zext nneg i32 %10550 to i64, !dbg !351 + %10552 = or disjoint i64 %10551, 4611686293338849280, !dbg !351 + %10553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 0, !dbg !351 + %10554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 1, !dbg !351 + %10555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 2, !dbg !351 + %10556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 3, !dbg !351 + %10557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 4, !dbg !351 + %10558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 5, !dbg !351 + %10559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 6, !dbg !351 + %10560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 7, !dbg !351 + %10561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 8, !dbg !351 + %10562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 9, !dbg !351 + %10563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 10, !dbg !351 + %10564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 11, !dbg !351 + %10565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 12, !dbg !351 + %10566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 13, !dbg !351 + %10567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 14, !dbg !351 + %10568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 15, !dbg !351 + %10569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 16, !dbg !351 + %10570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 17, !dbg !351 + %10571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 18, !dbg !351 + %10572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 19, !dbg !351 + %10573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 20, !dbg !351 + %10574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 21, !dbg !351 + %10575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 22, !dbg !351 + %10576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 23, !dbg !351 + %10577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 24, !dbg !351 + %10578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 25, !dbg !351 + %10579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 26, !dbg !351 + %10580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 27, !dbg !351 + %10581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 28, !dbg !351 + %10582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 29, !dbg !351 + %10583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 30, !dbg !351 + %10584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 31, !dbg !351 + %10585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 32, !dbg !351 + %10586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 33, !dbg !351 + %10587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 34, !dbg !351 + %10588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 35, !dbg !351 + %10589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 36, !dbg !351 + %10590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 37, !dbg !351 + %10591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 38, !dbg !351 + %10592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 39, !dbg !351 + %10593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 40, !dbg !351 + %10594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 41, !dbg !351 + %10595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 42, !dbg !351 + %10596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 43, !dbg !351 + %10597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 44, !dbg !351 + %10598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 45, !dbg !351 + %10599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 46, !dbg !351 + %10600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 47, !dbg !351 + %10601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 48, !dbg !351 + %10602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 49, !dbg !351 + %10603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 50, !dbg !351 + %10604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 51, !dbg !351 + %10605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 52, !dbg !351 + %10606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 53, !dbg !351 + %10607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 54, !dbg !351 + %10608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 55, !dbg !351 + %10609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 56, !dbg !351 + %10610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 57, !dbg !351 + %10611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 58, !dbg !351 + %10612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 59, !dbg !351 + %10613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 60, !dbg !351 + %10614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 61, !dbg !351 + %10615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 62, !dbg !351 + %10616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10547, 63, !dbg !351 + %10617 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10553, float %10554, float %10555, float %10556, float %10557, float %10558, float %10559, float %10560, float %10561, float %10562, float %10563, float %10564, float %10565, float %10566, float %10567, float %10568, float %10569, float %10570, float %10571, float %10572, float %10573, float %10574, float %10575, float %10576, float %10577, float %10578, float %10579, float %10580, float %10581, float %10582, float %10583, float %10584, float %10585, float %10586, float %10587, float %10588, float %10589, float %10590, float %10591, float %10592, float %10593, float %10594, float %10595, float %10596, float %10597, float %10598, float %10599, float %10600, float %10601, float %10602, float %10603, float %10604, float %10605, float %10606, float %10607, float %10608, float %10609, float %10610, float %10611, float %10612, float %10613, float %10614, float %10615, float %10616, i32 %10455, i32 %10458, i32 %10461, i32 %10464, i64 %10552, i1 true) #3, !dbg !351 + %10618 = add i32 %8884, 6144, !dbg !351 + %10619 = lshr exact i32 %10618, 4, !dbg !351 + %10620 = and i32 %10619, 16383, !dbg !351 + %10621 = zext nneg i32 %10620 to i64, !dbg !351 + %10622 = or disjoint i64 %10621, 4611686293338849280, !dbg !351 + %10623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 0, !dbg !351 + %10624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 1, !dbg !351 + %10625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 2, !dbg !351 + %10626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 3, !dbg !351 + %10627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 4, !dbg !351 + %10628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 5, !dbg !351 + %10629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 6, !dbg !351 + %10630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 7, !dbg !351 + %10631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 8, !dbg !351 + %10632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 9, !dbg !351 + %10633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 10, !dbg !351 + %10634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 11, !dbg !351 + %10635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 12, !dbg !351 + %10636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 13, !dbg !351 + %10637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 14, !dbg !351 + %10638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 15, !dbg !351 + %10639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 16, !dbg !351 + %10640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 17, !dbg !351 + %10641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 18, !dbg !351 + %10642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 19, !dbg !351 + %10643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 20, !dbg !351 + %10644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 21, !dbg !351 + %10645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 22, !dbg !351 + %10646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 23, !dbg !351 + %10647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 24, !dbg !351 + %10648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 25, !dbg !351 + %10649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 26, !dbg !351 + %10650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 27, !dbg !351 + %10651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 28, !dbg !351 + %10652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 29, !dbg !351 + %10653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 30, !dbg !351 + %10654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 31, !dbg !351 + %10655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 32, !dbg !351 + %10656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 33, !dbg !351 + %10657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 34, !dbg !351 + %10658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 35, !dbg !351 + %10659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 36, !dbg !351 + %10660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 37, !dbg !351 + %10661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 38, !dbg !351 + %10662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 39, !dbg !351 + %10663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 40, !dbg !351 + %10664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 41, !dbg !351 + %10665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 42, !dbg !351 + %10666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 43, !dbg !351 + %10667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 44, !dbg !351 + %10668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 45, !dbg !351 + %10669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 46, !dbg !351 + %10670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 47, !dbg !351 + %10671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 48, !dbg !351 + %10672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 49, !dbg !351 + %10673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 50, !dbg !351 + %10674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 51, !dbg !351 + %10675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 52, !dbg !351 + %10676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 53, !dbg !351 + %10677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 54, !dbg !351 + %10678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 55, !dbg !351 + %10679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 56, !dbg !351 + %10680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 57, !dbg !351 + %10681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 58, !dbg !351 + %10682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 59, !dbg !351 + %10683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 60, !dbg !351 + %10684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 61, !dbg !351 + %10685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 62, !dbg !351 + %10686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10617, 63, !dbg !351 + %10687 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10623, float %10624, float %10625, float %10626, float %10627, float %10628, float %10629, float %10630, float %10631, float %10632, float %10633, float %10634, float %10635, float %10636, float %10637, float %10638, float %10639, float %10640, float %10641, float %10642, float %10643, float %10644, float %10645, float %10646, float %10647, float %10648, float %10649, float %10650, float %10651, float %10652, float %10653, float %10654, float %10655, float %10656, float %10657, float %10658, float %10659, float %10660, float %10661, float %10662, float %10663, float %10664, float %10665, float %10666, float %10667, float %10668, float %10669, float %10670, float %10671, float %10672, float %10673, float %10674, float %10675, float %10676, float %10677, float %10678, float %10679, float %10680, float %10681, float %10682, float %10683, float %10684, float %10685, float %10686, i32 %10467, i32 %10470, i32 %10473, i32 %10476, i64 %10622, i1 true) #3, !dbg !351 + %10688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 0, !dbg !351 + %10689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 1, !dbg !351 + %10690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 2, !dbg !351 + %10691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 3, !dbg !351 + %10692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 4, !dbg !351 + %10693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 5, !dbg !351 + %10694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 6, !dbg !351 + %10695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 7, !dbg !351 + %10696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 8, !dbg !351 + %10697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 9, !dbg !351 + %10698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 10, !dbg !351 + %10699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 11, !dbg !351 + %10700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 12, !dbg !351 + %10701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 13, !dbg !351 + %10702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 14, !dbg !351 + %10703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 15, !dbg !351 + %10704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 16, !dbg !351 + %10705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 17, !dbg !351 + %10706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 18, !dbg !351 + %10707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 19, !dbg !351 + %10708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 20, !dbg !351 + %10709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 21, !dbg !351 + %10710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 22, !dbg !351 + %10711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 23, !dbg !351 + %10712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 24, !dbg !351 + %10713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 25, !dbg !351 + %10714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 26, !dbg !351 + %10715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 27, !dbg !351 + %10716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 28, !dbg !351 + %10717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 29, !dbg !351 + %10718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 30, !dbg !351 + %10719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 31, !dbg !351 + %10720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 32, !dbg !351 + %10721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 33, !dbg !351 + %10722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 34, !dbg !351 + %10723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 35, !dbg !351 + %10724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 36, !dbg !351 + %10725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 37, !dbg !351 + %10726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 38, !dbg !351 + %10727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 39, !dbg !351 + %10728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 40, !dbg !351 + %10729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 41, !dbg !351 + %10730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 42, !dbg !351 + %10731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 43, !dbg !351 + %10732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 44, !dbg !351 + %10733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 45, !dbg !351 + %10734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 46, !dbg !351 + %10735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 47, !dbg !351 + %10736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 48, !dbg !351 + %10737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 49, !dbg !351 + %10738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 50, !dbg !351 + %10739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 51, !dbg !351 + %10740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 52, !dbg !351 + %10741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 53, !dbg !351 + %10742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 54, !dbg !351 + %10743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 55, !dbg !351 + %10744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 56, !dbg !351 + %10745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 57, !dbg !351 + %10746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 58, !dbg !351 + %10747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 59, !dbg !351 + %10748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 60, !dbg !351 + %10749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 61, !dbg !351 + %10750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 62, !dbg !351 + %10751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10687, 63, !dbg !351 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !351 + %10752 = add nuw nsw i32 %8783, 1, !dbg !334 + %10753 = lshr i32 %10752, 1, !dbg !352 + %10754 = zext nneg i32 %10753 to i64, !dbg !353 + %10755 = getelementptr i32, ptr addrspace(1) %5046, i64 %10754, !dbg !353 + %10756 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !354 + %10757 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %10755, i64 %10756, i1 %8785) #3, !dbg !354 + %10758 = add nuw nsw i32 %10753, 1, !dbg !355 + %10759 = icmp slt i32 %10758, %5050, !dbg !356 + %10760 = getelementptr i8, ptr addrspace(1) %10755, i64 4, !dbg !357 + %10761 = and i1 %8785, %10759, !dbg !334 + %10762 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !358 + %10763 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %10760, i64 %10762, i1 %10761) #3, !dbg !358 + %10764 = and i32 %8783, 1, !dbg !359 + %10765 = sub i32 %10763, %10757, !dbg !360 + %10766 = shl i32 %10765, 7, !dbg !361 + %10767 = add i32 %10766, -64, !dbg !362 + %10768 = xor i32 %10764, 1, !dbg !363 + %10769 = mul nuw nsw i32 %10767, %10768, !dbg !363 + %10770 = shl nuw nsw i32 %10764, 6, !dbg !364 + %10771 = add i32 %10769, %10770, !dbg !365 + %10772 = shl i32 %10771, 12, !dbg !366 + %10773 = sext i32 %10772 to i64, !dbg !332 + %10774 = getelementptr bfloat, ptr addrspace(1) %.pn5391832, i64 %10773, !dbg !332 + %10775 = getelementptr bfloat, ptr addrspace(1) %.pn5231833, i64 %10773, !dbg !332 + %10776 = getelementptr bfloat, ptr addrspace(1) %.pn5071834, i64 %10773, !dbg !332 + %10777 = getelementptr bfloat, ptr addrspace(1) %.pn4911835, i64 %10773, !dbg !332 + %10778 = shl i32 %10771, 7, !dbg !367 + %10779 = sext i32 %10778 to i64, !dbg !333 + %10780 = getelementptr bfloat, ptr addrspace(1) %.pn6031836, i64 %10779, !dbg !333 + %10781 = getelementptr bfloat, ptr addrspace(1) %.pn5871837, i64 %10779, !dbg !333 + %10782 = getelementptr bfloat, ptr addrspace(1) %.pn5711838, i64 %10779, !dbg !333 + %10783 = getelementptr bfloat, ptr addrspace(1) %.pn5551839, i64 %10779, !dbg !333 + %10784 = add i32 %10771, %.pn6351840, !dbg !368 + %10785 = add i32 %10771, %.pn6331841, !dbg !368 + %10786 = add i32 %10771, %.pn6311842, !dbg !368 + %10787 = add i32 %10771, %.pn6291843, !dbg !368 + %10788 = add i32 %10771, %.pn6271844, !dbg !368 + %10789 = add i32 %10771, %.pn6251845, !dbg !368 + %10790 = add i32 %10771, %.pn6231846, !dbg !368 + %10791 = add i32 %10771, %.pn6211847, !dbg !368 + %10792 = add i32 %10771, %.pn6191848, !dbg !368 + %10793 = add i32 %10771, %.pn6171849, !dbg !368 + %10794 = add i32 %10771, %.pn6151850, !dbg !368 + %10795 = add i32 %10771, %.pn6131851, !dbg !368 + %10796 = add i32 %10771, %.pn6111852, !dbg !368 + %10797 = add i32 %10771, %.pn6091853, !dbg !368 + %10798 = add i32 %10771, %.pn6071854, !dbg !368 + %10799 = add i32 %10771, %.pn6051855, !dbg !368 + %10800 = add i32 %10771, %8779, !dbg !368 + %10801 = add i32 %10771, %8780, !dbg !368 + %10802 = add i32 %10771, %8781, !dbg !368 + %10803 = add i32 %10771, %8782, !dbg !368 + %10804 = add i32 %10771, %8775, !dbg !368 + %10805 = add i32 %10771, %8776, !dbg !368 + %10806 = add i32 %10771, %8777, !dbg !368 + %10807 = add i32 %10771, %8778, !dbg !368 + %10808 = add i32 %8772, 1, !dbg !334 + %10809 = icmp sgt i32 %10808, 1, !dbg !334 + %10810 = select i1 %10809, i32 0, i32 %10808, !dbg !334 + %10811 = add i32 %8774, 1, !dbg !334 + %10812 = icmp sgt i32 %10811, 2, !dbg !334 + %10813 = select i1 %10812, i32 0, i32 %10811, !dbg !334 + %10814 = icmp slt i32 %10800, %17, !dbg !335 + %10815 = icmp slt i32 %10801, %17, !dbg !335 + %10816 = icmp slt i32 %10802, %17, !dbg !335 + %10817 = icmp slt i32 %10803, %17, !dbg !335 + %10818 = shl i32 %10813, 13, !dbg !326 + %10819 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %10818, !dbg !326 + %10820 = and i1 %8784, %10814, !dbg !334 + %10821 = and i1 %8784, %10815, !dbg !334 + %10822 = and i1 %8784, %10816, !dbg !334 + %10823 = and i1 %8784, %10817, !dbg !334 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !326 + %10824 = getelementptr inbounds nuw i8, ptr addrspace(3) %10819, i32 %5111, !dbg !326 + %10825 = select i1 %10820, i32 16, i32 0, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %10824, ptr addrspace(1) %10774, i32 %10825) #3, !dbg !326 + %10826 = getelementptr inbounds nuw i8, ptr addrspace(3) %10819, i32 %5114, !dbg !326 + %10827 = select i1 %10821, i32 16, i32 0, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10826, ptr addrspace(1) %10775, i32 %10827) #3, !dbg !326 + %10828 = getelementptr inbounds nuw i8, ptr addrspace(3) %10819, i32 %5117, !dbg !326 + %10829 = select i1 %10822, i32 16, i32 0, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10828, ptr addrspace(1) %10776, i32 %10829) #3, !dbg !326 + %10830 = getelementptr inbounds nuw i8, ptr addrspace(3) %10819, i32 %5120, !dbg !326 + %10831 = select i1 %10823, i32 16, i32 0, !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10830, ptr addrspace(1) %10777, i32 %10831) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + %10832 = icmp slt i32 %10784, %17, !dbg !369 + %10833 = icmp slt i32 %10785, %17, !dbg !369 + %10834 = icmp slt i32 %10786, %17, !dbg !369 + %10835 = icmp slt i32 %10787, %17, !dbg !369 + %10836 = icmp slt i32 %10788, %17, !dbg !369 + %10837 = icmp slt i32 %10789, %17, !dbg !369 + %10838 = icmp slt i32 %10790, %17, !dbg !369 + %10839 = icmp slt i32 %10791, %17, !dbg !369 + %10840 = icmp slt i32 %10792, %17, !dbg !369 + %10841 = icmp slt i32 %10793, %17, !dbg !369 + %10842 = icmp slt i32 %10794, %17, !dbg !369 + %10843 = icmp slt i32 %10795, %17, !dbg !369 + %10844 = icmp slt i32 %10796, %17, !dbg !369 + %10845 = icmp slt i32 %10797, %17, !dbg !369 + %10846 = icmp slt i32 %10798, %17, !dbg !369 + %10847 = icmp slt i32 %10799, %17, !dbg !369 + %10848 = sext i32 %10784 to i64, !dbg !327 + %10849 = getelementptr float, ptr addrspace(1) %5728, i64 %10848, !dbg !327 + %10850 = sext i32 %10785 to i64, !dbg !327 + %10851 = getelementptr float, ptr addrspace(1) %5728, i64 %10850, !dbg !327 + %10852 = sext i32 %10786 to i64, !dbg !327 + %10853 = getelementptr float, ptr addrspace(1) %5728, i64 %10852, !dbg !327 + %10854 = sext i32 %10787 to i64, !dbg !327 + %10855 = getelementptr float, ptr addrspace(1) %5728, i64 %10854, !dbg !327 + %10856 = sext i32 %10788 to i64, !dbg !327 + %10857 = getelementptr float, ptr addrspace(1) %5728, i64 %10856, !dbg !327 + %10858 = sext i32 %10789 to i64, !dbg !327 + %10859 = getelementptr float, ptr addrspace(1) %5728, i64 %10858, !dbg !327 + %10860 = sext i32 %10790 to i64, !dbg !327 + %10861 = getelementptr float, ptr addrspace(1) %5728, i64 %10860, !dbg !327 + %10862 = sext i32 %10791 to i64, !dbg !327 + %10863 = getelementptr float, ptr addrspace(1) %5728, i64 %10862, !dbg !327 + %10864 = sext i32 %10792 to i64, !dbg !327 + %10865 = getelementptr float, ptr addrspace(1) %5728, i64 %10864, !dbg !327 + %10866 = sext i32 %10793 to i64, !dbg !327 + %10867 = getelementptr float, ptr addrspace(1) %5728, i64 %10866, !dbg !327 + %10868 = sext i32 %10794 to i64, !dbg !327 + %10869 = getelementptr float, ptr addrspace(1) %5728, i64 %10868, !dbg !327 + %10870 = sext i32 %10795 to i64, !dbg !327 + %10871 = getelementptr float, ptr addrspace(1) %5728, i64 %10870, !dbg !327 + %10872 = sext i32 %10796 to i64, !dbg !327 + %10873 = getelementptr float, ptr addrspace(1) %5728, i64 %10872, !dbg !327 + %10874 = sext i32 %10797 to i64, !dbg !327 + %10875 = getelementptr float, ptr addrspace(1) %5728, i64 %10874, !dbg !327 + %10876 = sext i32 %10798 to i64, !dbg !327 + %10877 = getelementptr float, ptr addrspace(1) %5728, i64 %10876, !dbg !327 + %10878 = sext i32 %10799 to i64, !dbg !327 + %10879 = getelementptr float, ptr addrspace(1) %5728, i64 %10878, !dbg !327 + %10880 = shl i32 %10810, 6, !dbg !328 + %10881 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %10880, !dbg !328 + %10882 = and i1 %8784, %10832, !dbg !334 + %10883 = and i1 %8784, %10833, !dbg !334 + %10884 = and i1 %8784, %10834, !dbg !334 + %10885 = and i1 %8784, %10835, !dbg !334 + %10886 = and i1 %8784, %10836, !dbg !334 + %10887 = and i1 %8784, %10837, !dbg !334 + %10888 = and i1 %8784, %10838, !dbg !334 + %10889 = and i1 %8784, %10839, !dbg !334 + %10890 = and i1 %8784, %10840, !dbg !334 + %10891 = and i1 %8784, %10841, !dbg !334 + %10892 = and i1 %8784, %10842, !dbg !334 + %10893 = and i1 %8784, %10843, !dbg !334 + %10894 = and i1 %8784, %10844, !dbg !334 + %10895 = and i1 %8784, %10845, !dbg !334 + %10896 = and i1 %8784, %10846, !dbg !334 + %10897 = and i1 %8784, %10847, !dbg !334 + %10898 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5189, !dbg !328 + %10899 = select i1 %10882, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %10898, ptr addrspace(1) %10849, i32 %10899, i1 %5188) #3, !dbg !328 + %10900 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5192, !dbg !328 + %10901 = select i1 %10883, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10900, ptr addrspace(1) %10851, i32 %10901, i1 %5188) #3, !dbg !328 + %10902 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5195, !dbg !328 + %10903 = select i1 %10884, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10902, ptr addrspace(1) %10853, i32 %10903, i1 %5188) #3, !dbg !328 + %10904 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5198, !dbg !328 + %10905 = select i1 %10885, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10904, ptr addrspace(1) %10855, i32 %10905, i1 %5188) #3, !dbg !328 + %10906 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5201, !dbg !328 + %10907 = select i1 %10886, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10906, ptr addrspace(1) %10857, i32 %10907, i1 %5188) #3, !dbg !328 + %10908 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5204, !dbg !328 + %10909 = select i1 %10887, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10908, ptr addrspace(1) %10859, i32 %10909, i1 %5188) #3, !dbg !328 + %10910 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5207, !dbg !328 + %10911 = select i1 %10888, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10910, ptr addrspace(1) %10861, i32 %10911, i1 %5188) #3, !dbg !328 + %10912 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5210, !dbg !328 + %10913 = select i1 %10889, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10912, ptr addrspace(1) %10863, i32 %10913, i1 %5188) #3, !dbg !328 + %10914 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5213, !dbg !328 + %10915 = select i1 %10890, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10914, ptr addrspace(1) %10865, i32 %10915, i1 %5188) #3, !dbg !328 + %10916 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5216, !dbg !328 + %10917 = select i1 %10891, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10916, ptr addrspace(1) %10867, i32 %10917, i1 %5188) #3, !dbg !328 + %10918 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5219, !dbg !328 + %10919 = select i1 %10892, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10918, ptr addrspace(1) %10869, i32 %10919, i1 %5188) #3, !dbg !328 + %10920 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5222, !dbg !328 + %10921 = select i1 %10893, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10920, ptr addrspace(1) %10871, i32 %10921, i1 %5188) #3, !dbg !328 + %10922 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5225, !dbg !328 + %10923 = select i1 %10894, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10922, ptr addrspace(1) %10873, i32 %10923, i1 %5188) #3, !dbg !328 + %10924 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5228, !dbg !328 + %10925 = select i1 %10895, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10924, ptr addrspace(1) %10875, i32 %10925, i1 %5188) #3, !dbg !328 + %10926 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5231, !dbg !328 + %10927 = select i1 %10896, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10926, ptr addrspace(1) %10877, i32 %10927, i1 %5188) #3, !dbg !328 + %10928 = getelementptr inbounds nuw i8, ptr addrspace(3) %10881, i32 %5234, !dbg !328 + %10929 = select i1 %10897, i32 4, i32 0, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10928, ptr addrspace(1) %10879, i32 %10929, i1 %5188) #3, !dbg !328 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !328 + %10930 = icmp slt i32 %10804, %17, !dbg !370 + %10931 = icmp slt i32 %10805, %17, !dbg !370 + %10932 = icmp slt i32 %10806, %17, !dbg !370 + %10933 = icmp slt i32 %10807, %17, !dbg !370 + %10934 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %10818, !dbg !329 + %10935 = and i1 %8784, %10930, !dbg !334 + %10936 = and i1 %8784, %10931, !dbg !334 + %10937 = and i1 %8784, %10932, !dbg !334 + %10938 = and i1 %8784, %10933, !dbg !334 + %10939 = getelementptr inbounds nuw i8, ptr addrspace(3) %10934, i32 %5111, !dbg !329 + %10940 = select i1 %10935, i32 16, i32 0, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %10939, ptr addrspace(1) %10780, i32 %10940) #3, !dbg !329 + %10941 = getelementptr inbounds nuw i8, ptr addrspace(3) %10934, i32 %5114, !dbg !329 + %10942 = select i1 %10936, i32 16, i32 0, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10941, ptr addrspace(1) %10781, i32 %10942) #3, !dbg !329 + %10943 = getelementptr inbounds nuw i8, ptr addrspace(3) %10934, i32 %5117, !dbg !329 + %10944 = select i1 %10937, i32 16, i32 0, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10943, ptr addrspace(1) %10782, i32 %10944) #3, !dbg !329 + %10945 = getelementptr inbounds nuw i8, ptr addrspace(3) %10934, i32 %5120, !dbg !329 + %10946 = select i1 %10938, i32 16, i32 0, !dbg !329 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10945, ptr addrspace(1) %10783, i32 %10946) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + %10947 = getelementptr float, ptr addrspace(1) %5729, i64 %10848, !dbg !330 + %10948 = getelementptr float, ptr addrspace(1) %5729, i64 %10850, !dbg !330 + %10949 = getelementptr float, ptr addrspace(1) %5729, i64 %10852, !dbg !330 + %10950 = getelementptr float, ptr addrspace(1) %5729, i64 %10854, !dbg !330 + %10951 = getelementptr float, ptr addrspace(1) %5729, i64 %10856, !dbg !330 + %10952 = getelementptr float, ptr addrspace(1) %5729, i64 %10858, !dbg !330 + %10953 = getelementptr float, ptr addrspace(1) %5729, i64 %10860, !dbg !330 + %10954 = getelementptr float, ptr addrspace(1) %5729, i64 %10862, !dbg !330 + %10955 = getelementptr float, ptr addrspace(1) %5729, i64 %10864, !dbg !330 + %10956 = getelementptr float, ptr addrspace(1) %5729, i64 %10866, !dbg !330 + %10957 = getelementptr float, ptr addrspace(1) %5729, i64 %10868, !dbg !330 + %10958 = getelementptr float, ptr addrspace(1) %5729, i64 %10870, !dbg !330 + %10959 = getelementptr float, ptr addrspace(1) %5729, i64 %10872, !dbg !330 + %10960 = getelementptr float, ptr addrspace(1) %5729, i64 %10874, !dbg !330 + %10961 = getelementptr float, ptr addrspace(1) %5729, i64 %10876, !dbg !330 + %10962 = getelementptr float, ptr addrspace(1) %5729, i64 %10878, !dbg !330 + %10963 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %10880, !dbg !331 + %10964 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5189, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %10964, ptr addrspace(1) %10947, i32 %10899, i1 %5188) #3, !dbg !331 + %10965 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5192, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10965, ptr addrspace(1) %10948, i32 %10901, i1 %5188) #3, !dbg !331 + %10966 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5195, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10966, ptr addrspace(1) %10949, i32 %10903, i1 %5188) #3, !dbg !331 + %10967 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5198, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10967, ptr addrspace(1) %10950, i32 %10905, i1 %5188) #3, !dbg !331 + %10968 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5201, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10968, ptr addrspace(1) %10951, i32 %10907, i1 %5188) #3, !dbg !331 + %10969 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5204, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10969, ptr addrspace(1) %10952, i32 %10909, i1 %5188) #3, !dbg !331 + %10970 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5207, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10970, ptr addrspace(1) %10953, i32 %10911, i1 %5188) #3, !dbg !331 + %10971 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5210, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10971, ptr addrspace(1) %10954, i32 %10913, i1 %5188) #3, !dbg !331 + %10972 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5213, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10972, ptr addrspace(1) %10955, i32 %10915, i1 %5188) #3, !dbg !331 + %10973 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5216, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10973, ptr addrspace(1) %10956, i32 %10917, i1 %5188) #3, !dbg !331 + %10974 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5219, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10974, ptr addrspace(1) %10957, i32 %10919, i1 %5188) #3, !dbg !331 + %10975 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5222, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10975, ptr addrspace(1) %10958, i32 %10921, i1 %5188) #3, !dbg !331 + %10976 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5225, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10976, ptr addrspace(1) %10959, i32 %10923, i1 %5188) #3, !dbg !331 + %10977 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5228, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10977, ptr addrspace(1) %10960, i32 %10925, i1 %5188) #3, !dbg !331 + %10978 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5231, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10978, ptr addrspace(1) %10961, i32 %10927, i1 %5188) #3, !dbg !331 + %10979 = getelementptr inbounds nuw i8, ptr addrspace(3) %10963, i32 %5234, !dbg !331 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10979, ptr addrspace(1) %10962, i32 %10929, i1 %5188) #3, !dbg !331 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !331 + %exitcond2268.not = icmp eq i32 %10752, %smax2267, !dbg !334 + br i1 %exitcond2268.not, label %._crit_edge1874, label %.lr.ph1873, !dbg !334 + +._crit_edge1874: ; preds = %__nv_exp2f.exit1321, %._crit_edge1701 + %.pn347.lcssa = phi float [ %8619, %._crit_edge1701 ], [ %10688, %__nv_exp2f.exit1321 ] + %.pn345.lcssa = phi float [ %8620, %._crit_edge1701 ], [ %10689, %__nv_exp2f.exit1321 ] + %.pn343.lcssa = phi float [ %8621, %._crit_edge1701 ], [ %10690, %__nv_exp2f.exit1321 ] + %.pn341.lcssa = phi float [ %8622, %._crit_edge1701 ], [ %10691, %__nv_exp2f.exit1321 ] + %.pn339.lcssa = phi float [ %8623, %._crit_edge1701 ], [ %10692, %__nv_exp2f.exit1321 ] + %.pn337.lcssa = phi float [ %8624, %._crit_edge1701 ], [ %10693, %__nv_exp2f.exit1321 ] + %.pn335.lcssa = phi float [ %8625, %._crit_edge1701 ], [ %10694, %__nv_exp2f.exit1321 ] + %.pn333.lcssa = phi float [ %8626, %._crit_edge1701 ], [ %10695, %__nv_exp2f.exit1321 ] + %.pn331.lcssa = phi float [ %8627, %._crit_edge1701 ], [ %10696, %__nv_exp2f.exit1321 ] + %.pn329.lcssa = phi float [ %8628, %._crit_edge1701 ], [ %10697, %__nv_exp2f.exit1321 ] + %.pn327.lcssa = phi float [ %8629, %._crit_edge1701 ], [ %10698, %__nv_exp2f.exit1321 ] + %.pn325.lcssa = phi float [ %8630, %._crit_edge1701 ], [ %10699, %__nv_exp2f.exit1321 ] + %.pn323.lcssa = phi float [ %8631, %._crit_edge1701 ], [ %10700, %__nv_exp2f.exit1321 ] + %.pn321.lcssa = phi float [ %8632, %._crit_edge1701 ], [ %10701, %__nv_exp2f.exit1321 ] + %.pn319.lcssa = phi float [ %8633, %._crit_edge1701 ], [ %10702, %__nv_exp2f.exit1321 ] + %.pn317.lcssa = phi float [ %8634, %._crit_edge1701 ], [ %10703, %__nv_exp2f.exit1321 ] + %.pn315.lcssa = phi float [ %8635, %._crit_edge1701 ], [ %10704, %__nv_exp2f.exit1321 ] + %.pn313.lcssa = phi float [ %8636, %._crit_edge1701 ], [ %10705, %__nv_exp2f.exit1321 ] + %.pn311.lcssa = phi float [ %8637, %._crit_edge1701 ], [ %10706, %__nv_exp2f.exit1321 ] + %.pn309.lcssa = phi float [ %8638, %._crit_edge1701 ], [ %10707, %__nv_exp2f.exit1321 ] + %.pn307.lcssa = phi float [ %8639, %._crit_edge1701 ], [ %10708, %__nv_exp2f.exit1321 ] + %.pn305.lcssa = phi float [ %8640, %._crit_edge1701 ], [ %10709, %__nv_exp2f.exit1321 ] + %.pn303.lcssa = phi float [ %8641, %._crit_edge1701 ], [ %10710, %__nv_exp2f.exit1321 ] + %.pn301.lcssa = phi float [ %8642, %._crit_edge1701 ], [ %10711, %__nv_exp2f.exit1321 ] + %.pn299.lcssa = phi float [ %8643, %._crit_edge1701 ], [ %10712, %__nv_exp2f.exit1321 ] + %.pn297.lcssa = phi float [ %8644, %._crit_edge1701 ], [ %10713, %__nv_exp2f.exit1321 ] + %.pn295.lcssa = phi float [ %8645, %._crit_edge1701 ], [ %10714, %__nv_exp2f.exit1321 ] + %.pn293.lcssa = phi float [ %8646, %._crit_edge1701 ], [ %10715, %__nv_exp2f.exit1321 ] + %.pn291.lcssa = phi float [ %8647, %._crit_edge1701 ], [ %10716, %__nv_exp2f.exit1321 ] + %.pn289.lcssa = phi float [ %8648, %._crit_edge1701 ], [ %10717, %__nv_exp2f.exit1321 ] + %.pn287.lcssa = phi float [ %8649, %._crit_edge1701 ], [ %10718, %__nv_exp2f.exit1321 ] + %.pn285.lcssa = phi float [ %8650, %._crit_edge1701 ], [ %10719, %__nv_exp2f.exit1321 ] + %.pn283.lcssa = phi float [ %8651, %._crit_edge1701 ], [ %10720, %__nv_exp2f.exit1321 ] + %.pn281.lcssa = phi float [ %8652, %._crit_edge1701 ], [ %10721, %__nv_exp2f.exit1321 ] + %.pn279.lcssa = phi float [ %8653, %._crit_edge1701 ], [ %10722, %__nv_exp2f.exit1321 ] + %.pn277.lcssa = phi float [ %8654, %._crit_edge1701 ], [ %10723, %__nv_exp2f.exit1321 ] + %.pn275.lcssa = phi float [ %8655, %._crit_edge1701 ], [ %10724, %__nv_exp2f.exit1321 ] + %.pn273.lcssa = phi float [ %8656, %._crit_edge1701 ], [ %10725, %__nv_exp2f.exit1321 ] + %.pn271.lcssa = phi float [ %8657, %._crit_edge1701 ], [ %10726, %__nv_exp2f.exit1321 ] + %.pn269.lcssa = phi float [ %8658, %._crit_edge1701 ], [ %10727, %__nv_exp2f.exit1321 ] + %.pn267.lcssa = phi float [ %8659, %._crit_edge1701 ], [ %10728, %__nv_exp2f.exit1321 ] + %.pn265.lcssa = phi float [ %8660, %._crit_edge1701 ], [ %10729, %__nv_exp2f.exit1321 ] + %.pn263.lcssa = phi float [ %8661, %._crit_edge1701 ], [ %10730, %__nv_exp2f.exit1321 ] + %.pn261.lcssa = phi float [ %8662, %._crit_edge1701 ], [ %10731, %__nv_exp2f.exit1321 ] + %.pn259.lcssa = phi float [ %8663, %._crit_edge1701 ], [ %10732, %__nv_exp2f.exit1321 ] + %.pn257.lcssa = phi float [ %8664, %._crit_edge1701 ], [ %10733, %__nv_exp2f.exit1321 ] + %.pn255.lcssa = phi float [ %8665, %._crit_edge1701 ], [ %10734, %__nv_exp2f.exit1321 ] + %.pn253.lcssa = phi float [ %8666, %._crit_edge1701 ], [ %10735, %__nv_exp2f.exit1321 ] + %.pn251.lcssa = phi float [ %8667, %._crit_edge1701 ], [ %10736, %__nv_exp2f.exit1321 ] + %.pn249.lcssa = phi float [ %8668, %._crit_edge1701 ], [ %10737, %__nv_exp2f.exit1321 ] + %.pn247.lcssa = phi float [ %8669, %._crit_edge1701 ], [ %10738, %__nv_exp2f.exit1321 ] + %.pn245.lcssa = phi float [ %8670, %._crit_edge1701 ], [ %10739, %__nv_exp2f.exit1321 ] + %.pn243.lcssa = phi float [ %8671, %._crit_edge1701 ], [ %10740, %__nv_exp2f.exit1321 ] + %.pn241.lcssa = phi float [ %8672, %._crit_edge1701 ], [ %10741, %__nv_exp2f.exit1321 ] + %.pn239.lcssa = phi float [ %8673, %._crit_edge1701 ], [ %10742, %__nv_exp2f.exit1321 ] + %.pn237.lcssa = phi float [ %8674, %._crit_edge1701 ], [ %10743, %__nv_exp2f.exit1321 ] + %.pn235.lcssa = phi float [ %8675, %._crit_edge1701 ], [ %10744, %__nv_exp2f.exit1321 ] + %.pn233.lcssa = phi float [ %8676, %._crit_edge1701 ], [ %10745, %__nv_exp2f.exit1321 ] + %.pn231.lcssa = phi float [ %8677, %._crit_edge1701 ], [ %10746, %__nv_exp2f.exit1321 ] + %.pn229.lcssa = phi float [ %8678, %._crit_edge1701 ], [ %10747, %__nv_exp2f.exit1321 ] + %.pn227.lcssa = phi float [ %8679, %._crit_edge1701 ], [ %10748, %__nv_exp2f.exit1321 ] + %.pn225.lcssa = phi float [ %8680, %._crit_edge1701 ], [ %10749, %__nv_exp2f.exit1321 ] + %.pn223.lcssa = phi float [ %8681, %._crit_edge1701 ], [ %10750, %__nv_exp2f.exit1321 ] + %.pn221.lcssa = phi float [ %8682, %._crit_edge1701 ], [ %10751, %__nv_exp2f.exit1321 ] + %.pn475.lcssa = phi float [ %8555, %._crit_edge1701 ], [ %9832, %__nv_exp2f.exit1321 ] + %.pn473.lcssa = phi float [ %8556, %._crit_edge1701 ], [ %9833, %__nv_exp2f.exit1321 ] + %.pn471.lcssa = phi float [ %8557, %._crit_edge1701 ], [ %9834, %__nv_exp2f.exit1321 ] + %.pn469.lcssa = phi float [ %8558, %._crit_edge1701 ], [ %9835, %__nv_exp2f.exit1321 ] + %.pn467.lcssa = phi float [ %8559, %._crit_edge1701 ], [ %9836, %__nv_exp2f.exit1321 ] + %.pn465.lcssa = phi float [ %8560, %._crit_edge1701 ], [ %9837, %__nv_exp2f.exit1321 ] + %.pn463.lcssa = phi float [ %8561, %._crit_edge1701 ], [ %9838, %__nv_exp2f.exit1321 ] + %.pn461.lcssa = phi float [ %8562, %._crit_edge1701 ], [ %9839, %__nv_exp2f.exit1321 ] + %.pn459.lcssa = phi float [ %8563, %._crit_edge1701 ], [ %9840, %__nv_exp2f.exit1321 ] + %.pn457.lcssa = phi float [ %8564, %._crit_edge1701 ], [ %9841, %__nv_exp2f.exit1321 ] + %.pn455.lcssa = phi float [ %8565, %._crit_edge1701 ], [ %9842, %__nv_exp2f.exit1321 ] + %.pn453.lcssa = phi float [ %8566, %._crit_edge1701 ], [ %9843, %__nv_exp2f.exit1321 ] + %.pn451.lcssa = phi float [ %8567, %._crit_edge1701 ], [ %9844, %__nv_exp2f.exit1321 ] + %.pn449.lcssa = phi float [ %8568, %._crit_edge1701 ], [ %9845, %__nv_exp2f.exit1321 ] + %.pn447.lcssa = phi float [ %8569, %._crit_edge1701 ], [ %9846, %__nv_exp2f.exit1321 ] + %.pn445.lcssa = phi float [ %8570, %._crit_edge1701 ], [ %9847, %__nv_exp2f.exit1321 ] + %.pn443.lcssa = phi float [ %8571, %._crit_edge1701 ], [ %9848, %__nv_exp2f.exit1321 ] + %.pn441.lcssa = phi float [ %8572, %._crit_edge1701 ], [ %9849, %__nv_exp2f.exit1321 ] + %.pn439.lcssa = phi float [ %8573, %._crit_edge1701 ], [ %9850, %__nv_exp2f.exit1321 ] + %.pn437.lcssa = phi float [ %8574, %._crit_edge1701 ], [ %9851, %__nv_exp2f.exit1321 ] + %.pn435.lcssa = phi float [ %8575, %._crit_edge1701 ], [ %9852, %__nv_exp2f.exit1321 ] + %.pn433.lcssa = phi float [ %8576, %._crit_edge1701 ], [ %9853, %__nv_exp2f.exit1321 ] + %.pn431.lcssa = phi float [ %8577, %._crit_edge1701 ], [ %9854, %__nv_exp2f.exit1321 ] + %.pn429.lcssa = phi float [ %8578, %._crit_edge1701 ], [ %9855, %__nv_exp2f.exit1321 ] + %.pn427.lcssa = phi float [ %8579, %._crit_edge1701 ], [ %9856, %__nv_exp2f.exit1321 ] + %.pn425.lcssa = phi float [ %8580, %._crit_edge1701 ], [ %9857, %__nv_exp2f.exit1321 ] + %.pn423.lcssa = phi float [ %8581, %._crit_edge1701 ], [ %9858, %__nv_exp2f.exit1321 ] + %.pn421.lcssa = phi float [ %8582, %._crit_edge1701 ], [ %9859, %__nv_exp2f.exit1321 ] + %.pn419.lcssa = phi float [ %8583, %._crit_edge1701 ], [ %9860, %__nv_exp2f.exit1321 ] + %.pn417.lcssa = phi float [ %8584, %._crit_edge1701 ], [ %9861, %__nv_exp2f.exit1321 ] + %.pn415.lcssa = phi float [ %8585, %._crit_edge1701 ], [ %9862, %__nv_exp2f.exit1321 ] + %.pn413.lcssa = phi float [ %8586, %._crit_edge1701 ], [ %9863, %__nv_exp2f.exit1321 ] + %.pn411.lcssa = phi float [ %8587, %._crit_edge1701 ], [ %9864, %__nv_exp2f.exit1321 ] + %.pn409.lcssa = phi float [ %8588, %._crit_edge1701 ], [ %9865, %__nv_exp2f.exit1321 ] + %.pn407.lcssa = phi float [ %8589, %._crit_edge1701 ], [ %9866, %__nv_exp2f.exit1321 ] + %.pn405.lcssa = phi float [ %8590, %._crit_edge1701 ], [ %9867, %__nv_exp2f.exit1321 ] + %.pn403.lcssa = phi float [ %8591, %._crit_edge1701 ], [ %9868, %__nv_exp2f.exit1321 ] + %.pn401.lcssa = phi float [ %8592, %._crit_edge1701 ], [ %9869, %__nv_exp2f.exit1321 ] + %.pn399.lcssa = phi float [ %8593, %._crit_edge1701 ], [ %9870, %__nv_exp2f.exit1321 ] + %.pn397.lcssa = phi float [ %8594, %._crit_edge1701 ], [ %9871, %__nv_exp2f.exit1321 ] + %.pn395.lcssa = phi float [ %8595, %._crit_edge1701 ], [ %9872, %__nv_exp2f.exit1321 ] + %.pn393.lcssa = phi float [ %8596, %._crit_edge1701 ], [ %9873, %__nv_exp2f.exit1321 ] + %.pn391.lcssa = phi float [ %8597, %._crit_edge1701 ], [ %9874, %__nv_exp2f.exit1321 ] + %.pn389.lcssa = phi float [ %8598, %._crit_edge1701 ], [ %9875, %__nv_exp2f.exit1321 ] + %.pn387.lcssa = phi float [ %8599, %._crit_edge1701 ], [ %9876, %__nv_exp2f.exit1321 ] + %.pn385.lcssa = phi float [ %8600, %._crit_edge1701 ], [ %9877, %__nv_exp2f.exit1321 ] + %.pn383.lcssa = phi float [ %8601, %._crit_edge1701 ], [ %9878, %__nv_exp2f.exit1321 ] + %.pn381.lcssa = phi float [ %8602, %._crit_edge1701 ], [ %9879, %__nv_exp2f.exit1321 ] + %.pn379.lcssa = phi float [ %8603, %._crit_edge1701 ], [ %9880, %__nv_exp2f.exit1321 ] + %.pn377.lcssa = phi float [ %8604, %._crit_edge1701 ], [ %9881, %__nv_exp2f.exit1321 ] + %.pn375.lcssa = phi float [ %8605, %._crit_edge1701 ], [ %9882, %__nv_exp2f.exit1321 ] + %.pn373.lcssa = phi float [ %8606, %._crit_edge1701 ], [ %9883, %__nv_exp2f.exit1321 ] + %.pn371.lcssa = phi float [ %8607, %._crit_edge1701 ], [ %9884, %__nv_exp2f.exit1321 ] + %.pn369.lcssa = phi float [ %8608, %._crit_edge1701 ], [ %9885, %__nv_exp2f.exit1321 ] + %.pn367.lcssa = phi float [ %8609, %._crit_edge1701 ], [ %9886, %__nv_exp2f.exit1321 ] + %.pn365.lcssa = phi float [ %8610, %._crit_edge1701 ], [ %9887, %__nv_exp2f.exit1321 ] + %.pn363.lcssa = phi float [ %8611, %._crit_edge1701 ], [ %9888, %__nv_exp2f.exit1321 ] + %.pn361.lcssa = phi float [ %8612, %._crit_edge1701 ], [ %9889, %__nv_exp2f.exit1321 ] + %.pn359.lcssa = phi float [ %8613, %._crit_edge1701 ], [ %9890, %__nv_exp2f.exit1321 ] + %.pn357.lcssa = phi float [ %8614, %._crit_edge1701 ], [ %9891, %__nv_exp2f.exit1321 ] + %.pn355.lcssa = phi float [ %8615, %._crit_edge1701 ], [ %9892, %__nv_exp2f.exit1321 ] + %.pn353.lcssa = phi float [ %8616, %._crit_edge1701 ], [ %9893, %__nv_exp2f.exit1321 ] + %.pn351.lcssa = phi float [ %8617, %._crit_edge1701 ], [ %9894, %__nv_exp2f.exit1321 ] + %.pn349.lcssa = phi float [ %8618, %._crit_edge1701 ], [ %9895, %__nv_exp2f.exit1321 ] + %10980 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %.pn475.lcssa, float %.pn473.lcssa, float %.pn471.lcssa, float %.pn469.lcssa, float %.pn467.lcssa, float %.pn465.lcssa, float %.pn463.lcssa, float %.pn461.lcssa, float %.pn459.lcssa, float %.pn457.lcssa, float %.pn455.lcssa, float %.pn453.lcssa, float %.pn451.lcssa, float %.pn449.lcssa, float %.pn447.lcssa, float %.pn445.lcssa, float %.pn443.lcssa, float %.pn441.lcssa, float %.pn439.lcssa, float %.pn437.lcssa, float %.pn435.lcssa, float %.pn433.lcssa, float %.pn431.lcssa, float %.pn429.lcssa, float %.pn427.lcssa, float %.pn425.lcssa, float %.pn423.lcssa, float %.pn421.lcssa, float %.pn419.lcssa, float %.pn417.lcssa, float %.pn415.lcssa, float %.pn413.lcssa, float %.pn411.lcssa, float %.pn409.lcssa, float %.pn407.lcssa, float %.pn405.lcssa, float %.pn403.lcssa, float %.pn401.lcssa, float %.pn399.lcssa, float %.pn397.lcssa, float %.pn395.lcssa, float %.pn393.lcssa, float %.pn391.lcssa, float %.pn389.lcssa, float %.pn387.lcssa, float %.pn385.lcssa, float %.pn383.lcssa, float %.pn381.lcssa, float %.pn379.lcssa, float %.pn377.lcssa, float %.pn375.lcssa, float %.pn373.lcssa, float %.pn371.lcssa, float %.pn369.lcssa, float %.pn367.lcssa, float %.pn365.lcssa, float %.pn363.lcssa, float %.pn361.lcssa, float %.pn359.lcssa, float %.pn357.lcssa, float %.pn355.lcssa, float %.pn353.lcssa, float %.pn351.lcssa, float %.pn349.lcssa, float %.pn347.lcssa, float %.pn345.lcssa, float %.pn343.lcssa, float %.pn341.lcssa, float %.pn339.lcssa, float %.pn337.lcssa, float %.pn335.lcssa, float %.pn333.lcssa, float %.pn331.lcssa, float %.pn329.lcssa, float %.pn327.lcssa, float %.pn325.lcssa, float %.pn323.lcssa, float %.pn321.lcssa, float %.pn319.lcssa, float %.pn317.lcssa, float %.pn315.lcssa, float %.pn313.lcssa, float %.pn311.lcssa, float %.pn309.lcssa, float %.pn307.lcssa, float %.pn305.lcssa, float %.pn303.lcssa, float %.pn301.lcssa, float %.pn299.lcssa, float %.pn297.lcssa, float %.pn295.lcssa, float %.pn293.lcssa, float %.pn291.lcssa, float %.pn289.lcssa, float %.pn287.lcssa, float %.pn285.lcssa, float %.pn283.lcssa, float %.pn281.lcssa, float %.pn279.lcssa, float %.pn277.lcssa, float %.pn275.lcssa, float %.pn273.lcssa, float %.pn271.lcssa, float %.pn269.lcssa, float %.pn267.lcssa, float %.pn265.lcssa, float %.pn263.lcssa, float %.pn261.lcssa, float %.pn259.lcssa, float %.pn257.lcssa, float %.pn255.lcssa, float %.pn253.lcssa, float %.pn251.lcssa, float %.pn249.lcssa, float %.pn247.lcssa, float %.pn245.lcssa, float %.pn243.lcssa, float %.pn241.lcssa, float %.pn239.lcssa, float %.pn237.lcssa, float %.pn235.lcssa, float %.pn233.lcssa, float %.pn231.lcssa, float %.pn229.lcssa, float %.pn227.lcssa, float %.pn225.lcssa, float %.pn223.lcssa, float %.pn221.lcssa) #3, !dbg !334 + %10981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 0, !dbg !334 + %10982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 1, !dbg !334 + %10983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 2, !dbg !334 + %10984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 3, !dbg !334 + %10985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 4, !dbg !334 + %10986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 5, !dbg !334 + %10987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 6, !dbg !334 + %10988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 7, !dbg !334 + %10989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 8, !dbg !334 + %10990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 9, !dbg !334 + %10991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 10, !dbg !334 + %10992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 11, !dbg !334 + %10993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 12, !dbg !334 + %10994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 13, !dbg !334 + %10995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 14, !dbg !334 + %10996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 15, !dbg !334 + %10997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 16, !dbg !334 + %10998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 17, !dbg !334 + %10999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 18, !dbg !334 + %11000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 19, !dbg !334 + %11001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 20, !dbg !334 + %11002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 21, !dbg !334 + %11003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 22, !dbg !334 + %11004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 23, !dbg !334 + %11005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 24, !dbg !334 + %11006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 25, !dbg !334 + %11007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 26, !dbg !334 + %11008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 27, !dbg !334 + %11009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 28, !dbg !334 + %11010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 29, !dbg !334 + %11011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 30, !dbg !334 + %11012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 31, !dbg !334 + %11013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 32, !dbg !334 + %11014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 33, !dbg !334 + %11015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 34, !dbg !334 + %11016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 35, !dbg !334 + %11017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 36, !dbg !334 + %11018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 37, !dbg !334 + %11019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 38, !dbg !334 + %11020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 39, !dbg !334 + %11021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 40, !dbg !334 + %11022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 41, !dbg !334 + %11023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 42, !dbg !334 + %11024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 43, !dbg !334 + %11025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 44, !dbg !334 + %11026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 45, !dbg !334 + %11027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 46, !dbg !334 + %11028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 47, !dbg !334 + %11029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 48, !dbg !334 + %11030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 49, !dbg !334 + %11031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 50, !dbg !334 + %11032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 51, !dbg !334 + %11033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 52, !dbg !334 + %11034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 53, !dbg !334 + %11035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 54, !dbg !334 + %11036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 55, !dbg !334 + %11037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 56, !dbg !334 + %11038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 57, !dbg !334 + %11039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 58, !dbg !334 + %11040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 59, !dbg !334 + %11041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 60, !dbg !334 + %11042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 61, !dbg !334 + %11043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 62, !dbg !334 + %11044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 63, !dbg !334 + %11045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 64, !dbg !334 + %11046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 65, !dbg !334 + %11047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 66, !dbg !334 + %11048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 67, !dbg !334 + %11049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 68, !dbg !334 + %11050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 69, !dbg !334 + %11051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 70, !dbg !334 + %11052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 71, !dbg !334 + %11053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 72, !dbg !334 + %11054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 73, !dbg !334 + %11055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 74, !dbg !334 + %11056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 75, !dbg !334 + %11057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 76, !dbg !334 + %11058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 77, !dbg !334 + %11059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 78, !dbg !334 + %11060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 79, !dbg !334 + %11061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 80, !dbg !334 + %11062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 81, !dbg !334 + %11063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 82, !dbg !334 + %11064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 83, !dbg !334 + %11065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 84, !dbg !334 + %11066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 85, !dbg !334 + %11067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 86, !dbg !334 + %11068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 87, !dbg !334 + %11069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 88, !dbg !334 + %11070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 89, !dbg !334 + %11071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 90, !dbg !334 + %11072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 91, !dbg !334 + %11073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 92, !dbg !334 + %11074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 93, !dbg !334 + %11075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 94, !dbg !334 + %11076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 95, !dbg !334 + %11077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 96, !dbg !334 + %11078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 97, !dbg !334 + %11079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 98, !dbg !334 + %11080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 99, !dbg !334 + %11081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 100, !dbg !334 + %11082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 101, !dbg !334 + %11083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 102, !dbg !334 + %11084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 103, !dbg !334 + %11085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 104, !dbg !334 + %11086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 105, !dbg !334 + %11087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 106, !dbg !334 + %11088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 107, !dbg !334 + %11089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 108, !dbg !334 + %11090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 109, !dbg !334 + %11091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 110, !dbg !334 + %11092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 111, !dbg !334 + %11093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 112, !dbg !334 + %11094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 113, !dbg !334 + %11095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 114, !dbg !334 + %11096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 115, !dbg !334 + %11097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 116, !dbg !334 + %11098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 117, !dbg !334 + %11099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 118, !dbg !334 + %11100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 119, !dbg !334 + %11101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 120, !dbg !334 + %11102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 121, !dbg !334 + %11103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 122, !dbg !334 + %11104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 123, !dbg !334 + %11105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 124, !dbg !334 + %11106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 125, !dbg !334 + %11107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 126, !dbg !334 + %11108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10980, 127, !dbg !334 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !334 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !334 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1, !dbg !251 + %exitcond2269.not = icmp eq i64 %indvars.iv.next, 4, !dbg !251 + br i1 %exitcond2269.not, label %11109, label %5585, !dbg !251 + +11109: ; preds = %._crit_edge1874 + %11110 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4758, !dbg !371 + %11111 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4760, !dbg !371 + %11112 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4762, !dbg !371 + %11113 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4764, !dbg !371 + %11114 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4766, !dbg !371 + %11115 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4768, !dbg !371 + %11116 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4770, !dbg !371 + %11117 = getelementptr bfloat, ptr addrspace(1) %49, i64 %4772, !dbg !371 + %11118 = getelementptr bfloat, ptr addrspace(1) %11110, i64 %4776, !dbg !372 + %11119 = getelementptr bfloat, ptr addrspace(1) %11111, i64 %4776, !dbg !372 + %11120 = getelementptr bfloat, ptr addrspace(1) %11112, i64 %4776, !dbg !372 + %11121 = getelementptr bfloat, ptr addrspace(1) %11113, i64 %4776, !dbg !372 + %11122 = getelementptr bfloat, ptr addrspace(1) %11114, i64 %4776, !dbg !372 + %11123 = getelementptr bfloat, ptr addrspace(1) %11115, i64 %4776, !dbg !372 + %11124 = getelementptr bfloat, ptr addrspace(1) %11116, i64 %4776, !dbg !372 + %11125 = getelementptr bfloat, ptr addrspace(1) %11117, i64 %4776, !dbg !372 + %11126 = insertelement <2 x float> poison, float %10981, i64 0, !dbg !373 + %11127 = insertelement <2 x float> %11126, float %10982, i64 1, !dbg !373 + %11128 = fptrunc <2 x float> %11127 to <2 x bfloat>, !dbg !373 + %11129 = insertelement <2 x float> poison, float %10983, i64 0, !dbg !373 + %11130 = insertelement <2 x float> %11129, float %10984, i64 1, !dbg !373 + %11131 = fptrunc <2 x float> %11130 to <2 x bfloat>, !dbg !373 + %11132 = insertelement <2 x float> poison, float %10985, i64 0, !dbg !373 + %11133 = insertelement <2 x float> %11132, float %10986, i64 1, !dbg !373 + %11134 = fptrunc <2 x float> %11133 to <2 x bfloat>, !dbg !373 + %11135 = insertelement <2 x float> poison, float %10987, i64 0, !dbg !373 + %11136 = insertelement <2 x float> %11135, float %10988, i64 1, !dbg !373 + %11137 = fptrunc <2 x float> %11136 to <2 x bfloat>, !dbg !373 + %11138 = insertelement <2 x float> poison, float %10989, i64 0, !dbg !373 + %11139 = insertelement <2 x float> %11138, float %10990, i64 1, !dbg !373 + %11140 = fptrunc <2 x float> %11139 to <2 x bfloat>, !dbg !373 + %11141 = insertelement <2 x float> poison, float %10991, i64 0, !dbg !373 + %11142 = insertelement <2 x float> %11141, float %10992, i64 1, !dbg !373 + %11143 = fptrunc <2 x float> %11142 to <2 x bfloat>, !dbg !373 + %11144 = insertelement <2 x float> poison, float %10993, i64 0, !dbg !373 + %11145 = insertelement <2 x float> %11144, float %10994, i64 1, !dbg !373 + %11146 = fptrunc <2 x float> %11145 to <2 x bfloat>, !dbg !373 + %11147 = insertelement <2 x float> poison, float %10995, i64 0, !dbg !373 + %11148 = insertelement <2 x float> %11147, float %10996, i64 1, !dbg !373 + %11149 = fptrunc <2 x float> %11148 to <2 x bfloat>, !dbg !373 + %11150 = insertelement <2 x float> poison, float %10997, i64 0, !dbg !373 + %11151 = insertelement <2 x float> %11150, float %10998, i64 1, !dbg !373 + %11152 = fptrunc <2 x float> %11151 to <2 x bfloat>, !dbg !373 + %11153 = insertelement <2 x float> poison, float %10999, i64 0, !dbg !373 + %11154 = insertelement <2 x float> %11153, float %11000, i64 1, !dbg !373 + %11155 = fptrunc <2 x float> %11154 to <2 x bfloat>, !dbg !373 + %11156 = insertelement <2 x float> poison, float %11001, i64 0, !dbg !373 + %11157 = insertelement <2 x float> %11156, float %11002, i64 1, !dbg !373 + %11158 = fptrunc <2 x float> %11157 to <2 x bfloat>, !dbg !373 + %11159 = insertelement <2 x float> poison, float %11003, i64 0, !dbg !373 + %11160 = insertelement <2 x float> %11159, float %11004, i64 1, !dbg !373 + %11161 = fptrunc <2 x float> %11160 to <2 x bfloat>, !dbg !373 + %11162 = insertelement <2 x float> poison, float %11005, i64 0, !dbg !373 + %11163 = insertelement <2 x float> %11162, float %11006, i64 1, !dbg !373 + %11164 = fptrunc <2 x float> %11163 to <2 x bfloat>, !dbg !373 + %11165 = insertelement <2 x float> poison, float %11007, i64 0, !dbg !373 + %11166 = insertelement <2 x float> %11165, float %11008, i64 1, !dbg !373 + %11167 = fptrunc <2 x float> %11166 to <2 x bfloat>, !dbg !373 + %11168 = insertelement <2 x float> poison, float %11009, i64 0, !dbg !373 + %11169 = insertelement <2 x float> %11168, float %11010, i64 1, !dbg !373 + %11170 = fptrunc <2 x float> %11169 to <2 x bfloat>, !dbg !373 + %11171 = insertelement <2 x float> poison, float %11011, i64 0, !dbg !373 + %11172 = insertelement <2 x float> %11171, float %11012, i64 1, !dbg !373 + %11173 = fptrunc <2 x float> %11172 to <2 x bfloat>, !dbg !373 + %11174 = insertelement <2 x float> poison, float %11013, i64 0, !dbg !373 + %11175 = insertelement <2 x float> %11174, float %11014, i64 1, !dbg !373 + %11176 = fptrunc <2 x float> %11175 to <2 x bfloat>, !dbg !373 + %11177 = insertelement <2 x float> poison, float %11015, i64 0, !dbg !373 + %11178 = insertelement <2 x float> %11177, float %11016, i64 1, !dbg !373 + %11179 = fptrunc <2 x float> %11178 to <2 x bfloat>, !dbg !373 + %11180 = insertelement <2 x float> poison, float %11017, i64 0, !dbg !373 + %11181 = insertelement <2 x float> %11180, float %11018, i64 1, !dbg !373 + %11182 = fptrunc <2 x float> %11181 to <2 x bfloat>, !dbg !373 + %11183 = insertelement <2 x float> poison, float %11019, i64 0, !dbg !373 + %11184 = insertelement <2 x float> %11183, float %11020, i64 1, !dbg !373 + %11185 = fptrunc <2 x float> %11184 to <2 x bfloat>, !dbg !373 + %11186 = insertelement <2 x float> poison, float %11021, i64 0, !dbg !373 + %11187 = insertelement <2 x float> %11186, float %11022, i64 1, !dbg !373 + %11188 = fptrunc <2 x float> %11187 to <2 x bfloat>, !dbg !373 + %11189 = insertelement <2 x float> poison, float %11023, i64 0, !dbg !373 + %11190 = insertelement <2 x float> %11189, float %11024, i64 1, !dbg !373 + %11191 = fptrunc <2 x float> %11190 to <2 x bfloat>, !dbg !373 + %11192 = insertelement <2 x float> poison, float %11025, i64 0, !dbg !373 + %11193 = insertelement <2 x float> %11192, float %11026, i64 1, !dbg !373 + %11194 = fptrunc <2 x float> %11193 to <2 x bfloat>, !dbg !373 + %11195 = insertelement <2 x float> poison, float %11027, i64 0, !dbg !373 + %11196 = insertelement <2 x float> %11195, float %11028, i64 1, !dbg !373 + %11197 = fptrunc <2 x float> %11196 to <2 x bfloat>, !dbg !373 + %11198 = insertelement <2 x float> poison, float %11029, i64 0, !dbg !373 + %11199 = insertelement <2 x float> %11198, float %11030, i64 1, !dbg !373 + %11200 = fptrunc <2 x float> %11199 to <2 x bfloat>, !dbg !373 + %11201 = insertelement <2 x float> poison, float %11031, i64 0, !dbg !373 + %11202 = insertelement <2 x float> %11201, float %11032, i64 1, !dbg !373 + %11203 = fptrunc <2 x float> %11202 to <2 x bfloat>, !dbg !373 + %11204 = insertelement <2 x float> poison, float %11033, i64 0, !dbg !373 + %11205 = insertelement <2 x float> %11204, float %11034, i64 1, !dbg !373 + %11206 = fptrunc <2 x float> %11205 to <2 x bfloat>, !dbg !373 + %11207 = insertelement <2 x float> poison, float %11035, i64 0, !dbg !373 + %11208 = insertelement <2 x float> %11207, float %11036, i64 1, !dbg !373 + %11209 = fptrunc <2 x float> %11208 to <2 x bfloat>, !dbg !373 + %11210 = insertelement <2 x float> poison, float %11037, i64 0, !dbg !373 + %11211 = insertelement <2 x float> %11210, float %11038, i64 1, !dbg !373 + %11212 = fptrunc <2 x float> %11211 to <2 x bfloat>, !dbg !373 + %11213 = insertelement <2 x float> poison, float %11039, i64 0, !dbg !373 + %11214 = insertelement <2 x float> %11213, float %11040, i64 1, !dbg !373 + %11215 = fptrunc <2 x float> %11214 to <2 x bfloat>, !dbg !373 + %11216 = insertelement <2 x float> poison, float %11041, i64 0, !dbg !373 + %11217 = insertelement <2 x float> %11216, float %11042, i64 1, !dbg !373 + %11218 = fptrunc <2 x float> %11217 to <2 x bfloat>, !dbg !373 + %11219 = insertelement <2 x float> poison, float %11043, i64 0, !dbg !373 + %11220 = insertelement <2 x float> %11219, float %11044, i64 1, !dbg !373 + %11221 = fptrunc <2 x float> %11220 to <2 x bfloat>, !dbg !373 + %11222 = shl nuw nsw i32 %4997, 13, !dbg !373 + %11223 = shl nuw nsw i32 %50, 5, !dbg !373 + %11224 = and i32 %11223, 7264, !dbg !373 + %11225 = and i32 %50, 24, !dbg !373 + %11226 = shl nuw nsw i32 %11225, 4, !dbg !373 + %11227 = shl nuw nsw i32 %50, 2, !dbg !373 + %11228 = and i32 %11227, 16, !dbg !373 + %11229 = or disjoint i32 %11222, %11228, !dbg !373 + %11230 = or disjoint i32 %11224, %11226, !dbg !373 + %11231 = or disjoint i32 %11229, %11230, !dbg !373 + %11232 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11231, !dbg !373 + %11233 = bitcast <2 x bfloat> %11128 to i32, !dbg !373 + %11234 = bitcast <2 x bfloat> %11134 to i32, !dbg !373 + %11235 = bitcast <2 x bfloat> %11140 to i32, !dbg !373 + %11236 = bitcast <2 x bfloat> %11146 to i32, !dbg !373 + %11237 = insertelement <4 x i32> poison, i32 %11233, i64 0, !dbg !373 + %11238 = insertelement <4 x i32> %11237, i32 %11234, i64 1, !dbg !373 + %11239 = insertelement <4 x i32> %11238, i32 %11235, i64 2, !dbg !373 + %11240 = insertelement <4 x i32> %11239, i32 %11236, i64 3, !dbg !373 + store <4 x i32> %11240, ptr addrspace(3) %11232, align 16, !dbg !373 + %11241 = getelementptr inbounds nuw i8, ptr addrspace(3) %11232, i32 512, !dbg !373 + %11242 = bitcast <2 x bfloat> %11131 to i32, !dbg !373 + %11243 = bitcast <2 x bfloat> %11137 to i32, !dbg !373 + %11244 = bitcast <2 x bfloat> %11143 to i32, !dbg !373 + %11245 = bitcast <2 x bfloat> %11149 to i32, !dbg !373 + %11246 = insertelement <4 x i32> poison, i32 %11242, i64 0, !dbg !373 + %11247 = insertelement <4 x i32> %11246, i32 %11243, i64 1, !dbg !373 + %11248 = insertelement <4 x i32> %11247, i32 %11244, i64 2, !dbg !373 + %11249 = insertelement <4 x i32> %11248, i32 %11245, i64 3, !dbg !373 + store <4 x i32> %11249, ptr addrspace(3) %11241, align 16, !dbg !373 + %11250 = xor i32 %11231, 32, !dbg !373 + %11251 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11250, !dbg !373 + %11252 = bitcast <2 x bfloat> %11152 to i32, !dbg !373 + %11253 = bitcast <2 x bfloat> %11158 to i32, !dbg !373 + %11254 = bitcast <2 x bfloat> %11164 to i32, !dbg !373 + %11255 = bitcast <2 x bfloat> %11170 to i32, !dbg !373 + %11256 = insertelement <4 x i32> poison, i32 %11252, i64 0, !dbg !373 + %11257 = insertelement <4 x i32> %11256, i32 %11253, i64 1, !dbg !373 + %11258 = insertelement <4 x i32> %11257, i32 %11254, i64 2, !dbg !373 + %11259 = insertelement <4 x i32> %11258, i32 %11255, i64 3, !dbg !373 + store <4 x i32> %11259, ptr addrspace(3) %11251, align 16, !dbg !373 + %11260 = getelementptr inbounds nuw i8, ptr addrspace(3) %11251, i32 512, !dbg !373 + %11261 = bitcast <2 x bfloat> %11155 to i32, !dbg !373 + %11262 = bitcast <2 x bfloat> %11161 to i32, !dbg !373 + %11263 = bitcast <2 x bfloat> %11167 to i32, !dbg !373 + %11264 = bitcast <2 x bfloat> %11173 to i32, !dbg !373 + %11265 = insertelement <4 x i32> poison, i32 %11261, i64 0, !dbg !373 + %11266 = insertelement <4 x i32> %11265, i32 %11262, i64 1, !dbg !373 + %11267 = insertelement <4 x i32> %11266, i32 %11263, i64 2, !dbg !373 + %11268 = insertelement <4 x i32> %11267, i32 %11264, i64 3, !dbg !373 + store <4 x i32> %11268, ptr addrspace(3) %11260, align 16, !dbg !373 + %11269 = xor i32 %11231, 64, !dbg !373 + %11270 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11269, !dbg !373 + %11271 = bitcast <2 x bfloat> %11176 to i32, !dbg !373 + %11272 = bitcast <2 x bfloat> %11182 to i32, !dbg !373 + %11273 = bitcast <2 x bfloat> %11188 to i32, !dbg !373 + %11274 = bitcast <2 x bfloat> %11194 to i32, !dbg !373 + %11275 = insertelement <4 x i32> poison, i32 %11271, i64 0, !dbg !373 + %11276 = insertelement <4 x i32> %11275, i32 %11272, i64 1, !dbg !373 + %11277 = insertelement <4 x i32> %11276, i32 %11273, i64 2, !dbg !373 + %11278 = insertelement <4 x i32> %11277, i32 %11274, i64 3, !dbg !373 + store <4 x i32> %11278, ptr addrspace(3) %11270, align 16, !dbg !373 + %11279 = getelementptr inbounds nuw i8, ptr addrspace(3) %11270, i32 512, !dbg !373 + %11280 = bitcast <2 x bfloat> %11179 to i32, !dbg !373 + %11281 = bitcast <2 x bfloat> %11185 to i32, !dbg !373 + %11282 = bitcast <2 x bfloat> %11191 to i32, !dbg !373 + %11283 = bitcast <2 x bfloat> %11197 to i32, !dbg !373 + %11284 = insertelement <4 x i32> poison, i32 %11280, i64 0, !dbg !373 + %11285 = insertelement <4 x i32> %11284, i32 %11281, i64 1, !dbg !373 + %11286 = insertelement <4 x i32> %11285, i32 %11282, i64 2, !dbg !373 + %11287 = insertelement <4 x i32> %11286, i32 %11283, i64 3, !dbg !373 + store <4 x i32> %11287, ptr addrspace(3) %11279, align 16, !dbg !373 + %11288 = xor i32 %11231, 96, !dbg !373 + %11289 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11288, !dbg !373 + %11290 = bitcast <2 x bfloat> %11200 to i32, !dbg !373 + %11291 = bitcast <2 x bfloat> %11206 to i32, !dbg !373 + %11292 = bitcast <2 x bfloat> %11212 to i32, !dbg !373 + %11293 = bitcast <2 x bfloat> %11218 to i32, !dbg !373 + %11294 = insertelement <4 x i32> poison, i32 %11290, i64 0, !dbg !373 + %11295 = insertelement <4 x i32> %11294, i32 %11291, i64 1, !dbg !373 + %11296 = insertelement <4 x i32> %11295, i32 %11292, i64 2, !dbg !373 + %11297 = insertelement <4 x i32> %11296, i32 %11293, i64 3, !dbg !373 + store <4 x i32> %11297, ptr addrspace(3) %11289, align 16, !dbg !373 + %11298 = getelementptr inbounds nuw i8, ptr addrspace(3) %11289, i32 512, !dbg !373 + %11299 = bitcast <2 x bfloat> %11203 to i32, !dbg !373 + %11300 = bitcast <2 x bfloat> %11209 to i32, !dbg !373 + %11301 = bitcast <2 x bfloat> %11215 to i32, !dbg !373 + %11302 = bitcast <2 x bfloat> %11221 to i32, !dbg !373 + %11303 = insertelement <4 x i32> poison, i32 %11299, i64 0, !dbg !373 + %11304 = insertelement <4 x i32> %11303, i32 %11300, i64 1, !dbg !373 + %11305 = insertelement <4 x i32> %11304, i32 %11301, i64 2, !dbg !373 + %11306 = insertelement <4 x i32> %11305, i32 %11302, i64 3, !dbg !373 + store <4 x i32> %11306, ptr addrspace(3) %11298, align 16, !dbg !373 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !373 + %11307 = shl nuw nsw i32 %11225, 10, !dbg !373 + %11308 = shl nuw nsw i32 %4997, 5, !dbg !373 + %11309 = and i32 %11227, 1008, !dbg !373 + %11310 = or disjoint i32 %11307, %11308, !dbg !373 + %11311 = xor i32 %11310, %11309, !dbg !373 + %11312 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11311, !dbg !373 + %11313 = ptrtoint ptr addrspace(3) %11312 to i32, !dbg !373 + %11314 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11313) #3, !dbg !373 + %11315 = extractvalue { i32, i32, i32, i32 } %11314, 0, !dbg !373 + %11316 = extractvalue { i32, i32, i32, i32 } %11314, 1, !dbg !373 + %11317 = extractvalue { i32, i32, i32, i32 } %11314, 2, !dbg !373 + %11318 = extractvalue { i32, i32, i32, i32 } %11314, 3, !dbg !373 + %11319 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 1024, !dbg !373 + %11320 = ptrtoint ptr addrspace(3) %11319 to i32, !dbg !373 + %11321 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11320) #3, !dbg !373 + %11322 = extractvalue { i32, i32, i32, i32 } %11321, 0, !dbg !373 + %11323 = extractvalue { i32, i32, i32, i32 } %11321, 1, !dbg !373 + %11324 = extractvalue { i32, i32, i32, i32 } %11321, 2, !dbg !373 + %11325 = extractvalue { i32, i32, i32, i32 } %11321, 3, !dbg !373 + %11326 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 2048, !dbg !373 + %11327 = ptrtoint ptr addrspace(3) %11326 to i32, !dbg !373 + %11328 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11327) #3, !dbg !373 + %11329 = extractvalue { i32, i32, i32, i32 } %11328, 0, !dbg !373 + %11330 = extractvalue { i32, i32, i32, i32 } %11328, 1, !dbg !373 + %11331 = extractvalue { i32, i32, i32, i32 } %11328, 2, !dbg !373 + %11332 = extractvalue { i32, i32, i32, i32 } %11328, 3, !dbg !373 + %11333 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 3072, !dbg !373 + %11334 = ptrtoint ptr addrspace(3) %11333 to i32, !dbg !373 + %11335 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11334) #3, !dbg !373 + %11336 = extractvalue { i32, i32, i32, i32 } %11335, 0, !dbg !373 + %11337 = extractvalue { i32, i32, i32, i32 } %11335, 1, !dbg !373 + %11338 = extractvalue { i32, i32, i32, i32 } %11335, 2, !dbg !373 + %11339 = extractvalue { i32, i32, i32, i32 } %11335, 3, !dbg !373 + %11340 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 4096, !dbg !373 + %11341 = ptrtoint ptr addrspace(3) %11340 to i32, !dbg !373 + %11342 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11341) #3, !dbg !373 + %11343 = extractvalue { i32, i32, i32, i32 } %11342, 0, !dbg !373 + %11344 = extractvalue { i32, i32, i32, i32 } %11342, 1, !dbg !373 + %11345 = extractvalue { i32, i32, i32, i32 } %11342, 2, !dbg !373 + %11346 = extractvalue { i32, i32, i32, i32 } %11342, 3, !dbg !373 + %11347 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 5120, !dbg !373 + %11348 = ptrtoint ptr addrspace(3) %11347 to i32, !dbg !373 + %11349 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11348) #3, !dbg !373 + %11350 = extractvalue { i32, i32, i32, i32 } %11349, 0, !dbg !373 + %11351 = extractvalue { i32, i32, i32, i32 } %11349, 1, !dbg !373 + %11352 = extractvalue { i32, i32, i32, i32 } %11349, 2, !dbg !373 + %11353 = extractvalue { i32, i32, i32, i32 } %11349, 3, !dbg !373 + %11354 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 6144, !dbg !373 + %11355 = ptrtoint ptr addrspace(3) %11354 to i32, !dbg !373 + %11356 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11355) #3, !dbg !373 + %11357 = extractvalue { i32, i32, i32, i32 } %11356, 0, !dbg !373 + %11358 = extractvalue { i32, i32, i32, i32 } %11356, 1, !dbg !373 + %11359 = extractvalue { i32, i32, i32, i32 } %11356, 2, !dbg !373 + %11360 = extractvalue { i32, i32, i32, i32 } %11356, 3, !dbg !373 + %11361 = getelementptr inbounds nuw i8, ptr addrspace(3) %11312, i32 7168, !dbg !373 + %11362 = ptrtoint ptr addrspace(3) %11361 to i32, !dbg !373 + %11363 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11362) #3, !dbg !373 + %11364 = extractvalue { i32, i32, i32, i32 } %11363, 0, !dbg !373 + %11365 = extractvalue { i32, i32, i32, i32 } %11363, 1, !dbg !373 + %11366 = extractvalue { i32, i32, i32, i32 } %11363, 2, !dbg !373 + %11367 = extractvalue { i32, i32, i32, i32 } %11363, 3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11315, i32 %11316, i32 %11317, i32 %11318, ptr addrspace(1) %11118, i1 %4785) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11322, i32 %11323, i32 %11324, i32 %11325, ptr addrspace(1) %11119, i1 %4786) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11329, i32 %11330, i32 %11331, i32 %11332, ptr addrspace(1) %11120, i1 %4787) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11336, i32 %11337, i32 %11338, i32 %11339, ptr addrspace(1) %11121, i1 %4788) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11343, i32 %11344, i32 %11345, i32 %11346, ptr addrspace(1) %11122, i1 %4789) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11350, i32 %11351, i32 %11352, i32 %11353, ptr addrspace(1) %11123, i1 %4790) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11357, i32 %11358, i32 %11359, i32 %11360, ptr addrspace(1) %11124, i1 %4791) #3, !dbg !373 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11364, i32 %11365, i32 %11366, i32 %11367, ptr addrspace(1) %11125, i1 %4792) #3, !dbg !373 + %11368 = insertelement <2 x float> poison, float %11045, i64 0, !dbg !374 + %11369 = insertelement <2 x float> %11368, float %11046, i64 1, !dbg !374 + %11370 = fmul <2 x float> %11369, splat (float 0x3FB6A09E60000000), !dbg !374 + %11371 = insertelement <2 x float> poison, float %11047, i64 0, !dbg !374 + %11372 = insertelement <2 x float> %11371, float %11048, i64 1, !dbg !374 + %11373 = fmul <2 x float> %11372, splat (float 0x3FB6A09E60000000), !dbg !374 + %11374 = insertelement <2 x float> poison, float %11049, i64 0, !dbg !374 + %11375 = insertelement <2 x float> %11374, float %11050, i64 1, !dbg !374 + %11376 = fmul <2 x float> %11375, splat (float 0x3FB6A09E60000000), !dbg !374 + %11377 = insertelement <2 x float> poison, float %11051, i64 0, !dbg !374 + %11378 = insertelement <2 x float> %11377, float %11052, i64 1, !dbg !374 + %11379 = fmul <2 x float> %11378, splat (float 0x3FB6A09E60000000), !dbg !374 + %11380 = insertelement <2 x float> poison, float %11053, i64 0, !dbg !374 + %11381 = insertelement <2 x float> %11380, float %11054, i64 1, !dbg !374 + %11382 = fmul <2 x float> %11381, splat (float 0x3FB6A09E60000000), !dbg !374 + %11383 = insertelement <2 x float> poison, float %11055, i64 0, !dbg !374 + %11384 = insertelement <2 x float> %11383, float %11056, i64 1, !dbg !374 + %11385 = fmul <2 x float> %11384, splat (float 0x3FB6A09E60000000), !dbg !374 + %11386 = insertelement <2 x float> poison, float %11057, i64 0, !dbg !374 + %11387 = insertelement <2 x float> %11386, float %11058, i64 1, !dbg !374 + %11388 = fmul <2 x float> %11387, splat (float 0x3FB6A09E60000000), !dbg !374 + %11389 = insertelement <2 x float> poison, float %11059, i64 0, !dbg !374 + %11390 = insertelement <2 x float> %11389, float %11060, i64 1, !dbg !374 + %11391 = fmul <2 x float> %11390, splat (float 0x3FB6A09E60000000), !dbg !374 + %11392 = insertelement <2 x float> poison, float %11061, i64 0, !dbg !374 + %11393 = insertelement <2 x float> %11392, float %11062, i64 1, !dbg !374 + %11394 = fmul <2 x float> %11393, splat (float 0x3FB6A09E60000000), !dbg !374 + %11395 = insertelement <2 x float> poison, float %11063, i64 0, !dbg !374 + %11396 = insertelement <2 x float> %11395, float %11064, i64 1, !dbg !374 + %11397 = fmul <2 x float> %11396, splat (float 0x3FB6A09E60000000), !dbg !374 + %11398 = insertelement <2 x float> poison, float %11065, i64 0, !dbg !374 + %11399 = insertelement <2 x float> %11398, float %11066, i64 1, !dbg !374 + %11400 = fmul <2 x float> %11399, splat (float 0x3FB6A09E60000000), !dbg !374 + %11401 = insertelement <2 x float> poison, float %11067, i64 0, !dbg !374 + %11402 = insertelement <2 x float> %11401, float %11068, i64 1, !dbg !374 + %11403 = fmul <2 x float> %11402, splat (float 0x3FB6A09E60000000), !dbg !374 + %11404 = insertelement <2 x float> poison, float %11069, i64 0, !dbg !374 + %11405 = insertelement <2 x float> %11404, float %11070, i64 1, !dbg !374 + %11406 = fmul <2 x float> %11405, splat (float 0x3FB6A09E60000000), !dbg !374 + %11407 = insertelement <2 x float> poison, float %11071, i64 0, !dbg !374 + %11408 = insertelement <2 x float> %11407, float %11072, i64 1, !dbg !374 + %11409 = fmul <2 x float> %11408, splat (float 0x3FB6A09E60000000), !dbg !374 + %11410 = insertelement <2 x float> poison, float %11073, i64 0, !dbg !374 + %11411 = insertelement <2 x float> %11410, float %11074, i64 1, !dbg !374 + %11412 = fmul <2 x float> %11411, splat (float 0x3FB6A09E60000000), !dbg !374 + %11413 = insertelement <2 x float> poison, float %11075, i64 0, !dbg !374 + %11414 = insertelement <2 x float> %11413, float %11076, i64 1, !dbg !374 + %11415 = fmul <2 x float> %11414, splat (float 0x3FB6A09E60000000), !dbg !374 + %11416 = insertelement <2 x float> poison, float %11077, i64 0, !dbg !374 + %11417 = insertelement <2 x float> %11416, float %11078, i64 1, !dbg !374 + %11418 = fmul <2 x float> %11417, splat (float 0x3FB6A09E60000000), !dbg !374 + %11419 = insertelement <2 x float> poison, float %11079, i64 0, !dbg !374 + %11420 = insertelement <2 x float> %11419, float %11080, i64 1, !dbg !374 + %11421 = fmul <2 x float> %11420, splat (float 0x3FB6A09E60000000), !dbg !374 + %11422 = insertelement <2 x float> poison, float %11081, i64 0, !dbg !374 + %11423 = insertelement <2 x float> %11422, float %11082, i64 1, !dbg !374 + %11424 = fmul <2 x float> %11423, splat (float 0x3FB6A09E60000000), !dbg !374 + %11425 = insertelement <2 x float> poison, float %11083, i64 0, !dbg !374 + %11426 = insertelement <2 x float> %11425, float %11084, i64 1, !dbg !374 + %11427 = fmul <2 x float> %11426, splat (float 0x3FB6A09E60000000), !dbg !374 + %11428 = insertelement <2 x float> poison, float %11085, i64 0, !dbg !374 + %11429 = insertelement <2 x float> %11428, float %11086, i64 1, !dbg !374 + %11430 = fmul <2 x float> %11429, splat (float 0x3FB6A09E60000000), !dbg !374 + %11431 = insertelement <2 x float> poison, float %11087, i64 0, !dbg !374 + %11432 = insertelement <2 x float> %11431, float %11088, i64 1, !dbg !374 + %11433 = fmul <2 x float> %11432, splat (float 0x3FB6A09E60000000), !dbg !374 + %11434 = insertelement <2 x float> poison, float %11089, i64 0, !dbg !374 + %11435 = insertelement <2 x float> %11434, float %11090, i64 1, !dbg !374 + %11436 = fmul <2 x float> %11435, splat (float 0x3FB6A09E60000000), !dbg !374 + %11437 = insertelement <2 x float> poison, float %11091, i64 0, !dbg !374 + %11438 = insertelement <2 x float> %11437, float %11092, i64 1, !dbg !374 + %11439 = fmul <2 x float> %11438, splat (float 0x3FB6A09E60000000), !dbg !374 + %11440 = insertelement <2 x float> poison, float %11093, i64 0, !dbg !374 + %11441 = insertelement <2 x float> %11440, float %11094, i64 1, !dbg !374 + %11442 = fmul <2 x float> %11441, splat (float 0x3FB6A09E60000000), !dbg !374 + %11443 = insertelement <2 x float> poison, float %11095, i64 0, !dbg !374 + %11444 = insertelement <2 x float> %11443, float %11096, i64 1, !dbg !374 + %11445 = fmul <2 x float> %11444, splat (float 0x3FB6A09E60000000), !dbg !374 + %11446 = insertelement <2 x float> poison, float %11097, i64 0, !dbg !374 + %11447 = insertelement <2 x float> %11446, float %11098, i64 1, !dbg !374 + %11448 = fmul <2 x float> %11447, splat (float 0x3FB6A09E60000000), !dbg !374 + %11449 = insertelement <2 x float> poison, float %11099, i64 0, !dbg !374 + %11450 = insertelement <2 x float> %11449, float %11100, i64 1, !dbg !374 + %11451 = fmul <2 x float> %11450, splat (float 0x3FB6A09E60000000), !dbg !374 + %11452 = insertelement <2 x float> poison, float %11101, i64 0, !dbg !374 + %11453 = insertelement <2 x float> %11452, float %11102, i64 1, !dbg !374 + %11454 = fmul <2 x float> %11453, splat (float 0x3FB6A09E60000000), !dbg !374 + %11455 = insertelement <2 x float> poison, float %11103, i64 0, !dbg !374 + %11456 = insertelement <2 x float> %11455, float %11104, i64 1, !dbg !374 + %11457 = fmul <2 x float> %11456, splat (float 0x3FB6A09E60000000), !dbg !374 + %11458 = insertelement <2 x float> poison, float %11105, i64 0, !dbg !374 + %11459 = insertelement <2 x float> %11458, float %11106, i64 1, !dbg !374 + %11460 = fmul <2 x float> %11459, splat (float 0x3FB6A09E60000000), !dbg !374 + %11461 = insertelement <2 x float> poison, float %11107, i64 0, !dbg !374 + %11462 = insertelement <2 x float> %11461, float %11108, i64 1, !dbg !374 + %11463 = fmul <2 x float> %11462, splat (float 0x3FB6A09E60000000), !dbg !374 + %11464 = or disjoint i32 %4775, %41, !dbg !375 + %11465 = add i32 %4750, %11464, !dbg !376 + %11466 = add i32 %4751, %11464, !dbg !376 + %11467 = add i32 %4752, %11464, !dbg !376 + %11468 = add i32 %4753, %11464, !dbg !376 + %11469 = add i32 %4754, %11464, !dbg !376 + %11470 = add i32 %4755, %11464, !dbg !376 + %11471 = add i32 %4756, %11464, !dbg !376 + %11472 = add i32 %4757, %11464, !dbg !376 + %11473 = sext i32 %11465 to i64, !dbg !377 + %11474 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11473, !dbg !377 + %11475 = sext i32 %11466 to i64, !dbg !377 + %11476 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11475, !dbg !377 + %11477 = sext i32 %11467 to i64, !dbg !377 + %11478 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11477, !dbg !377 + %11479 = sext i32 %11468 to i64, !dbg !377 + %11480 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11479, !dbg !377 + %11481 = sext i32 %11469 to i64, !dbg !377 + %11482 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11481, !dbg !377 + %11483 = sext i32 %11470 to i64, !dbg !377 + %11484 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11483, !dbg !377 + %11485 = sext i32 %11471 to i64, !dbg !377 + %11486 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11485, !dbg !377 + %11487 = sext i32 %11472 to i64, !dbg !377 + %11488 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11487, !dbg !377 + %11489 = fptrunc <2 x float> %11370 to <2 x bfloat>, !dbg !378 + %11490 = fptrunc <2 x float> %11373 to <2 x bfloat>, !dbg !378 + %11491 = fptrunc <2 x float> %11376 to <2 x bfloat>, !dbg !378 + %11492 = fptrunc <2 x float> %11379 to <2 x bfloat>, !dbg !378 + %11493 = fptrunc <2 x float> %11382 to <2 x bfloat>, !dbg !378 + %11494 = fptrunc <2 x float> %11385 to <2 x bfloat>, !dbg !378 + %11495 = fptrunc <2 x float> %11388 to <2 x bfloat>, !dbg !378 + %11496 = fptrunc <2 x float> %11391 to <2 x bfloat>, !dbg !378 + %11497 = fptrunc <2 x float> %11394 to <2 x bfloat>, !dbg !378 + %11498 = fptrunc <2 x float> %11397 to <2 x bfloat>, !dbg !378 + %11499 = fptrunc <2 x float> %11400 to <2 x bfloat>, !dbg !378 + %11500 = fptrunc <2 x float> %11403 to <2 x bfloat>, !dbg !378 + %11501 = fptrunc <2 x float> %11406 to <2 x bfloat>, !dbg !378 + %11502 = fptrunc <2 x float> %11409 to <2 x bfloat>, !dbg !378 + %11503 = fptrunc <2 x float> %11412 to <2 x bfloat>, !dbg !378 + %11504 = fptrunc <2 x float> %11415 to <2 x bfloat>, !dbg !378 + %11505 = fptrunc <2 x float> %11418 to <2 x bfloat>, !dbg !378 + %11506 = fptrunc <2 x float> %11421 to <2 x bfloat>, !dbg !378 + %11507 = fptrunc <2 x float> %11424 to <2 x bfloat>, !dbg !378 + %11508 = fptrunc <2 x float> %11427 to <2 x bfloat>, !dbg !378 + %11509 = fptrunc <2 x float> %11430 to <2 x bfloat>, !dbg !378 + %11510 = fptrunc <2 x float> %11433 to <2 x bfloat>, !dbg !378 + %11511 = fptrunc <2 x float> %11436 to <2 x bfloat>, !dbg !378 + %11512 = fptrunc <2 x float> %11439 to <2 x bfloat>, !dbg !378 + %11513 = fptrunc <2 x float> %11442 to <2 x bfloat>, !dbg !378 + %11514 = fptrunc <2 x float> %11445 to <2 x bfloat>, !dbg !378 + %11515 = fptrunc <2 x float> %11448 to <2 x bfloat>, !dbg !378 + %11516 = fptrunc <2 x float> %11451 to <2 x bfloat>, !dbg !378 + %11517 = fptrunc <2 x float> %11454 to <2 x bfloat>, !dbg !378 + %11518 = fptrunc <2 x float> %11457 to <2 x bfloat>, !dbg !378 + %11519 = fptrunc <2 x float> %11460 to <2 x bfloat>, !dbg !378 + %11520 = fptrunc <2 x float> %11463 to <2 x bfloat>, !dbg !378 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !378 + %11521 = bitcast <2 x bfloat> %11489 to i32, !dbg !378 + %11522 = bitcast <2 x bfloat> %11491 to i32, !dbg !378 + %11523 = bitcast <2 x bfloat> %11493 to i32, !dbg !378 + %11524 = bitcast <2 x bfloat> %11495 to i32, !dbg !378 + %11525 = insertelement <4 x i32> poison, i32 %11521, i64 0, !dbg !378 + %11526 = insertelement <4 x i32> %11525, i32 %11522, i64 1, !dbg !378 + %11527 = insertelement <4 x i32> %11526, i32 %11523, i64 2, !dbg !378 + %11528 = insertelement <4 x i32> %11527, i32 %11524, i64 3, !dbg !378 + store <4 x i32> %11528, ptr addrspace(3) %11232, align 16, !dbg !378 + %11529 = bitcast <2 x bfloat> %11490 to i32, !dbg !378 + %11530 = bitcast <2 x bfloat> %11492 to i32, !dbg !378 + %11531 = bitcast <2 x bfloat> %11494 to i32, !dbg !378 + %11532 = bitcast <2 x bfloat> %11496 to i32, !dbg !378 + %11533 = insertelement <4 x i32> poison, i32 %11529, i64 0, !dbg !378 + %11534 = insertelement <4 x i32> %11533, i32 %11530, i64 1, !dbg !378 + %11535 = insertelement <4 x i32> %11534, i32 %11531, i64 2, !dbg !378 + %11536 = insertelement <4 x i32> %11535, i32 %11532, i64 3, !dbg !378 + store <4 x i32> %11536, ptr addrspace(3) %11241, align 16, !dbg !378 + %11537 = bitcast <2 x bfloat> %11497 to i32, !dbg !378 + %11538 = bitcast <2 x bfloat> %11499 to i32, !dbg !378 + %11539 = bitcast <2 x bfloat> %11501 to i32, !dbg !378 + %11540 = bitcast <2 x bfloat> %11503 to i32, !dbg !378 + %11541 = insertelement <4 x i32> poison, i32 %11537, i64 0, !dbg !378 + %11542 = insertelement <4 x i32> %11541, i32 %11538, i64 1, !dbg !378 + %11543 = insertelement <4 x i32> %11542, i32 %11539, i64 2, !dbg !378 + %11544 = insertelement <4 x i32> %11543, i32 %11540, i64 3, !dbg !378 + store <4 x i32> %11544, ptr addrspace(3) %11251, align 16, !dbg !378 + %11545 = bitcast <2 x bfloat> %11498 to i32, !dbg !378 + %11546 = bitcast <2 x bfloat> %11500 to i32, !dbg !378 + %11547 = bitcast <2 x bfloat> %11502 to i32, !dbg !378 + %11548 = bitcast <2 x bfloat> %11504 to i32, !dbg !378 + %11549 = insertelement <4 x i32> poison, i32 %11545, i64 0, !dbg !378 + %11550 = insertelement <4 x i32> %11549, i32 %11546, i64 1, !dbg !378 + %11551 = insertelement <4 x i32> %11550, i32 %11547, i64 2, !dbg !378 + %11552 = insertelement <4 x i32> %11551, i32 %11548, i64 3, !dbg !378 + store <4 x i32> %11552, ptr addrspace(3) %11260, align 16, !dbg !378 + %11553 = bitcast <2 x bfloat> %11505 to i32, !dbg !378 + %11554 = bitcast <2 x bfloat> %11507 to i32, !dbg !378 + %11555 = bitcast <2 x bfloat> %11509 to i32, !dbg !378 + %11556 = bitcast <2 x bfloat> %11511 to i32, !dbg !378 + %11557 = insertelement <4 x i32> poison, i32 %11553, i64 0, !dbg !378 + %11558 = insertelement <4 x i32> %11557, i32 %11554, i64 1, !dbg !378 + %11559 = insertelement <4 x i32> %11558, i32 %11555, i64 2, !dbg !378 + %11560 = insertelement <4 x i32> %11559, i32 %11556, i64 3, !dbg !378 + store <4 x i32> %11560, ptr addrspace(3) %11270, align 16, !dbg !378 + %11561 = bitcast <2 x bfloat> %11506 to i32, !dbg !378 + %11562 = bitcast <2 x bfloat> %11508 to i32, !dbg !378 + %11563 = bitcast <2 x bfloat> %11510 to i32, !dbg !378 + %11564 = bitcast <2 x bfloat> %11512 to i32, !dbg !378 + %11565 = insertelement <4 x i32> poison, i32 %11561, i64 0, !dbg !378 + %11566 = insertelement <4 x i32> %11565, i32 %11562, i64 1, !dbg !378 + %11567 = insertelement <4 x i32> %11566, i32 %11563, i64 2, !dbg !378 + %11568 = insertelement <4 x i32> %11567, i32 %11564, i64 3, !dbg !378 + store <4 x i32> %11568, ptr addrspace(3) %11279, align 16, !dbg !378 + %11569 = bitcast <2 x bfloat> %11513 to i32, !dbg !378 + %11570 = bitcast <2 x bfloat> %11515 to i32, !dbg !378 + %11571 = bitcast <2 x bfloat> %11517 to i32, !dbg !378 + %11572 = bitcast <2 x bfloat> %11519 to i32, !dbg !378 + %11573 = insertelement <4 x i32> poison, i32 %11569, i64 0, !dbg !378 + %11574 = insertelement <4 x i32> %11573, i32 %11570, i64 1, !dbg !378 + %11575 = insertelement <4 x i32> %11574, i32 %11571, i64 2, !dbg !378 + %11576 = insertelement <4 x i32> %11575, i32 %11572, i64 3, !dbg !378 + store <4 x i32> %11576, ptr addrspace(3) %11289, align 16, !dbg !378 + %11577 = bitcast <2 x bfloat> %11514 to i32, !dbg !378 + %11578 = bitcast <2 x bfloat> %11516 to i32, !dbg !378 + %11579 = bitcast <2 x bfloat> %11518 to i32, !dbg !378 + %11580 = bitcast <2 x bfloat> %11520 to i32, !dbg !378 + %11581 = insertelement <4 x i32> poison, i32 %11577, i64 0, !dbg !378 + %11582 = insertelement <4 x i32> %11581, i32 %11578, i64 1, !dbg !378 + %11583 = insertelement <4 x i32> %11582, i32 %11579, i64 2, !dbg !378 + %11584 = insertelement <4 x i32> %11583, i32 %11580, i64 3, !dbg !378 + store <4 x i32> %11584, ptr addrspace(3) %11298, align 16, !dbg !378 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !378 + %11585 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11313) #3, !dbg !378 + %11586 = extractvalue { i32, i32, i32, i32 } %11585, 0, !dbg !378 + %11587 = extractvalue { i32, i32, i32, i32 } %11585, 1, !dbg !378 + %11588 = extractvalue { i32, i32, i32, i32 } %11585, 2, !dbg !378 + %11589 = extractvalue { i32, i32, i32, i32 } %11585, 3, !dbg !378 + %11590 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11320) #3, !dbg !378 + %11591 = extractvalue { i32, i32, i32, i32 } %11590, 0, !dbg !378 + %11592 = extractvalue { i32, i32, i32, i32 } %11590, 1, !dbg !378 + %11593 = extractvalue { i32, i32, i32, i32 } %11590, 2, !dbg !378 + %11594 = extractvalue { i32, i32, i32, i32 } %11590, 3, !dbg !378 + %11595 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11327) #3, !dbg !378 + %11596 = extractvalue { i32, i32, i32, i32 } %11595, 0, !dbg !378 + %11597 = extractvalue { i32, i32, i32, i32 } %11595, 1, !dbg !378 + %11598 = extractvalue { i32, i32, i32, i32 } %11595, 2, !dbg !378 + %11599 = extractvalue { i32, i32, i32, i32 } %11595, 3, !dbg !378 + %11600 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11334) #3, !dbg !378 + %11601 = extractvalue { i32, i32, i32, i32 } %11600, 0, !dbg !378 + %11602 = extractvalue { i32, i32, i32, i32 } %11600, 1, !dbg !378 + %11603 = extractvalue { i32, i32, i32, i32 } %11600, 2, !dbg !378 + %11604 = extractvalue { i32, i32, i32, i32 } %11600, 3, !dbg !378 + %11605 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11341) #3, !dbg !378 + %11606 = extractvalue { i32, i32, i32, i32 } %11605, 0, !dbg !378 + %11607 = extractvalue { i32, i32, i32, i32 } %11605, 1, !dbg !378 + %11608 = extractvalue { i32, i32, i32, i32 } %11605, 2, !dbg !378 + %11609 = extractvalue { i32, i32, i32, i32 } %11605, 3, !dbg !378 + %11610 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11348) #3, !dbg !378 + %11611 = extractvalue { i32, i32, i32, i32 } %11610, 0, !dbg !378 + %11612 = extractvalue { i32, i32, i32, i32 } %11610, 1, !dbg !378 + %11613 = extractvalue { i32, i32, i32, i32 } %11610, 2, !dbg !378 + %11614 = extractvalue { i32, i32, i32, i32 } %11610, 3, !dbg !378 + %11615 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11355) #3, !dbg !378 + %11616 = extractvalue { i32, i32, i32, i32 } %11615, 0, !dbg !378 + %11617 = extractvalue { i32, i32, i32, i32 } %11615, 1, !dbg !378 + %11618 = extractvalue { i32, i32, i32, i32 } %11615, 2, !dbg !378 + %11619 = extractvalue { i32, i32, i32, i32 } %11615, 3, !dbg !378 + %11620 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11362) #3, !dbg !378 + %11621 = extractvalue { i32, i32, i32, i32 } %11620, 0, !dbg !378 + %11622 = extractvalue { i32, i32, i32, i32 } %11620, 1, !dbg !378 + %11623 = extractvalue { i32, i32, i32, i32 } %11620, 2, !dbg !378 + %11624 = extractvalue { i32, i32, i32, i32 } %11620, 3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11586, i32 %11587, i32 %11588, i32 %11589, ptr addrspace(1) %11474, i1 %4785) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11591, i32 %11592, i32 %11593, i32 %11594, ptr addrspace(1) %11476, i1 %4786) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11596, i32 %11597, i32 %11598, i32 %11599, ptr addrspace(1) %11478, i1 %4787) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11601, i32 %11602, i32 %11603, i32 %11604, ptr addrspace(1) %11480, i1 %4788) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11606, i32 %11607, i32 %11608, i32 %11609, ptr addrspace(1) %11482, i1 %4789) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11611, i32 %11612, i32 %11613, i32 %11614, ptr addrspace(1) %11484, i1 %4790) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11616, i32 %11617, i32 %11618, i32 %11619, ptr addrspace(1) %11486, i1 %4791) #3, !dbg !378 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11621, i32 %11622, i32 %11623, i32 %11624, ptr addrspace(1) %11488, i1 %4792) #3, !dbg !378 + br label %11625, !dbg !35 + +11625: ; preds = %._crit_edge1673, %11109 + ret void, !dbg !379 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smax.i32(i32, i32) #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smin.i32(i32, i32) #1 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #2 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.commit.group() #3 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.wait.group(i32 immarg) #3 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.idx.i32(i32, i32, i32, i32) #4 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.fence.sync.aligned() #5 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.commit_group.sync.aligned() #5 + +declare i32 @__nvvm_reflect(ptr) local_unnamed_addr #6 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.ftz.f(float) #7 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.f(float) #7 + +attributes #0 = { nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind } +attributes #3 = { nounwind } +attributes #4 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #5 = { convergent nounwind } +attributes #6 = { "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #7 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py", directory: "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = distinct !DISubprogram(name: "triton_tem_fused_mul_1", linkageName: "triton_tem_fused_mul_1", scope: !1, file: !1, line: 18, type: !6, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!6 = !DISubroutineType(cc: DW_CC_normal, types: !7) +!7 = !{} +!8 = !DILocation(line: 94, column: 54, scope: !5) +!9 = !DILocation(line: 97, column: 74, scope: !5) +!10 = !DILocation(line: 97, column: 66, scope: !5) +!11 = !DILocation(line: 97, column: 100, scope: !5) +!12 = !DILocation(line: 97, column: 91, scope: !5) +!13 = !DILocation(line: 97, column: 82, scope: !5) +!14 = !DILocation(line: 97, column: 59, scope: !5) +!15 = !DILocation(line: 97, column: 111, scope: !5) +!16 = !DILocation(line: 111, column: 24, scope: !5) +!17 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !20) +!18 = distinct !DILexicalBlockFile(scope: !5, file: !19, discriminator: 0) +!19 = !DIFile(filename: "standard.py", directory: "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language") +!20 = !DILocation(line: 112, column: 36, scope: !5) +!21 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !20) +!22 = !DILocation(line: 115, column: 27, scope: !5) +!23 = !DILocation(line: 116, column: 28, scope: !5) +!24 = !DILocation(line: 124, column: 25, scope: !5) +!25 = !DILocation(line: 124, column: 59, scope: !5) +!26 = !DILocation(line: 100, column: 58, scope: !5) +!27 = !DILocation(line: 128, column: 50, scope: !5) +!28 = !DILocation(line: 128, column: 37, scope: !5) +!29 = !DILocation(line: 128, column: 61, scope: !5) +!30 = !DILocation(line: 131, column: 9, scope: !5) +!31 = !DILocation(line: 132, column: 9, scope: !5) +!32 = !DILocation(line: 133, column: 10, scope: !5) +!33 = !DILocation(line: 136, column: 26, scope: !5) +!34 = !DILocation(line: 139, column: 14, scope: !5) +!35 = !DILocation(line: 139, column: 7, scope: !5) +!36 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !37) +!37 = !DILocation(line: 113, column: 34, scope: !5) +!38 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !37) +!39 = !DILocation(line: 140, column: 24, scope: !5) +!40 = !DILocation(line: 144, column: 29, scope: !5) +!41 = !DILocation(line: 144, column: 54, scope: !5) +!42 = !DILocation(line: 144, column: 44, scope: !5) +!43 = !DILocation(line: 145, column: 35, scope: !5) +!44 = !DILocation(line: 155, column: 83, scope: !5) +!45 = !DILocation(line: 158, column: 30, scope: !5) +!46 = !DILocation(line: 158, column: 52, scope: !5) +!47 = !DILocation(line: 158, column: 40, scope: !5) +!48 = !DILocation(line: 158, column: 63, scope: !5) +!49 = !DILocation(line: 159, column: 32, scope: !5) +!50 = !DILocation(line: 159, column: 55, scope: !5) +!51 = !DILocation(line: 159, column: 42, scope: !5) +!52 = !DILocation(line: 159, column: 66, scope: !5) +!53 = !DILocation(line: 161, column: 30, scope: !5) +!54 = !DILocation(line: 161, column: 35, scope: !5) +!55 = !DILocation(line: 161, column: 46, scope: !5) +!56 = !DILocation(line: 161, column: 56, scope: !5) +!57 = !DILocation(line: 163, column: 17, scope: !5) +!58 = !DILocation(line: 164, column: 19, scope: !5) +!59 = !DILocation(line: 167, column: 19, scope: !5) +!60 = !DILocation(line: 168, column: 21, scope: !5) +!61 = !DILocation(line: 169, column: 25, scope: !5) +!62 = !DILocation(line: 174, column: 36, scope: !5) +!63 = !DILocation(line: 175, column: 29, scope: !5) +!64 = !DILocation(line: 789, column: 38, scope: !65, inlinedAt: !66) +!65 = distinct !DILexicalBlockFile(scope: !5, file: !1, discriminator: 0) +!66 = !DILocation(line: 178, column: 107, scope: !5) +!67 = !DILocation(line: 789, column: 20, scope: !65, inlinedAt: !66) +!68 = !DILocation(line: 789, column: 56, scope: !65, inlinedAt: !66) +!69 = !DILocation(line: 789, column: 49, scope: !65, inlinedAt: !66) +!70 = !DILocation(line: 797, column: 52, scope: !65, inlinedAt: !66) +!71 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !66) +!72 = !DILocation(line: 789, column: 38, scope: !65, inlinedAt: !73) +!73 = !DILocation(line: 179, column: 111, scope: !5) +!74 = !DILocation(line: 789, column: 20, scope: !65, inlinedAt: !73) +!75 = !DILocation(line: 789, column: 49, scope: !65, inlinedAt: !73) +!76 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !73) +!77 = !DILocation(line: 188, column: 58, scope: !5) +!78 = !DILocation(line: 188, column: 34, scope: !5) +!79 = !DILocation(line: 188, column: 25, scope: !5) +!80 = !DILocation(line: 189, column: 33, scope: !5) +!81 = !DILocation(line: 189, column: 26, scope: !5) +!82 = !DILocation(line: 190, column: 30, scope: !5) +!83 = !DILocation(line: 190, column: 50, scope: !5) +!84 = !DILocation(line: 195, column: 30, scope: !5) +!85 = !DILocation(line: 196, column: 27, scope: !5) +!86 = !DILocation(line: 196, column: 41, scope: !5) +!87 = !DILocation(line: 197, column: 53, scope: !5) +!88 = !DILocation(line: 197, column: 39, scope: !5) +!89 = !DILocation(line: 199, column: 42, scope: !5) +!90 = !DILocation(line: 199, column: 29, scope: !5) +!91 = !DILocation(line: 390, column: 37, scope: !65, inlinedAt: !92) +!92 = !DILocation(line: 207, column: 12, scope: !5) +!93 = !DILocation(line: 390, column: 18, scope: !65, inlinedAt: !92) +!94 = !DILocation(line: 390, column: 49, scope: !65, inlinedAt: !92) +!95 = !DILocation(line: 391, column: 18, scope: !65, inlinedAt: !92) +!96 = !DILocation(line: 391, column: 49, scope: !65, inlinedAt: !92) +!97 = !DILocation(line: 395, column: 43, scope: !65, inlinedAt: !92) +!98 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !92) +!99 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !92) +!100 = !DILocation(line: 395, column: 101, scope: !65, inlinedAt: !92) +!101 = !DILocation(line: 395, column: 63, scope: !65, inlinedAt: !92) +!102 = !DILocation(line: 397, column: 28, scope: !65, inlinedAt: !92) +!103 = !DILocation(line: 795, column: 52, scope: !65, inlinedAt: !92) +!104 = !DILocation(line: 795, column: 23, scope: !65, inlinedAt: !92) +!105 = !DILocation(line: 414, column: 19, scope: !65, inlinedAt: !92) +!106 = !DILocation(line: 415, column: 19, scope: !65, inlinedAt: !92) +!107 = !DILocation(line: 417, column: 19, scope: !65, inlinedAt: !92) +!108 = !DILocation(line: 459, column: 19, scope: !65, inlinedAt: !92) +!109 = !DILocation(line: 762, column: 21, scope: !65, inlinedAt: !92) +!110 = !DILocation(line: 492, column: 91, scope: !65, inlinedAt: !92) +!111 = !DILocation(line: 481, column: 22, scope: !65, inlinedAt: !92) +!112 = !DILocation(line: 492, column: 79, scope: !65, inlinedAt: !92) +!113 = !DILocation(line: 492, column: 119, scope: !65, inlinedAt: !92) +!114 = !DILocation(line: 485, column: 23, scope: !65, inlinedAt: !92) +!115 = !DILocation(line: 484, column: 22, scope: !65, inlinedAt: !92) +!116 = !DILocation(line: 483, column: 23, scope: !65, inlinedAt: !92) +!117 = !DILocation(line: 487, column: 22, scope: !65, inlinedAt: !92) +!118 = !DILocation(line: 495, column: 25, scope: !65, inlinedAt: !92) +!119 = !DILocation(line: 513, column: 19, scope: !65, inlinedAt: !92) +!120 = !DILocation(line: 461, column: 14, scope: !65, inlinedAt: !92) +!121 = !DILocation(line: 486, column: 22, scope: !65, inlinedAt: !92) +!122 = !DILocation(line: 489, column: 23, scope: !65, inlinedAt: !92) +!123 = !DILocation(line: 494, column: 79, scope: !65, inlinedAt: !92) +!124 = !DILocation(line: 494, column: 91, scope: !65, inlinedAt: !92) +!125 = !DILocation(line: 494, column: 119, scope: !65, inlinedAt: !92) +!126 = !DILocation(line: 496, column: 24, scope: !65, inlinedAt: !92) +!127 = !DILocation(line: 497, column: 23, scope: !65, inlinedAt: !92) +!128 = !DILocation(line: 498, column: 23, scope: !65, inlinedAt: !92) +!129 = !DILocation(line: 503, column: 69, scope: !65, inlinedAt: !92) +!130 = !DILocation(line: 506, column: 27, scope: !65, inlinedAt: !92) +!131 = !DILocation(line: 507, column: 39, scope: !65, inlinedAt: !92) +!132 = !DILocation(line: 507, column: 21, scope: !65, inlinedAt: !92) +!133 = !DILocation(line: 512, column: 20, scope: !65, inlinedAt: !92) +!134 = !DILocation(line: 513, column: 14, scope: !65, inlinedAt: !92) +!135 = !DILocation(line: 533, column: 15, scope: !65, inlinedAt: !92) +!136 = !DILocation(line: 531, column: 43, scope: !65, inlinedAt: !92) +!137 = !DILocation(line: 535, column: 21, scope: !65, inlinedAt: !92) +!138 = !DILocation(line: 752, column: 33, scope: !65, inlinedAt: !92) +!139 = !DILocation(line: 753, column: 38, scope: !65, inlinedAt: !92) +!140 = !DILocation(line: 753, column: 24, scope: !65, inlinedAt: !92) +!141 = !DILocation(line: 754, column: 109, scope: !65, inlinedAt: !92) +!142 = !DILocation(line: 754, column: 113, scope: !65, inlinedAt: !92) +!143 = !DILocation(line: 754, column: 55, scope: !65, inlinedAt: !92) +!144 = !DILocation(line: 754, column: 25, scope: !65, inlinedAt: !92) +!145 = !DILocation(line: 755, column: 35, scope: !65, inlinedAt: !92) +!146 = !DILocation(line: 756, column: 34, scope: !65, inlinedAt: !92) +!147 = !DILocation(line: 756, column: 48, scope: !65, inlinedAt: !92) +!148 = !DILocation(line: 756, column: 63, scope: !65, inlinedAt: !92) +!149 = !DILocation(line: 757, column: 29, scope: !65, inlinedAt: !92) +!150 = !DILocation(line: 757, column: 61, scope: !65, inlinedAt: !92) +!151 = !DILocation(line: 757, column: 42, scope: !65, inlinedAt: !92) +!152 = !DILocation(line: 414, column: 28, scope: !65, inlinedAt: !92) +!153 = !DILocation(line: 214, column: 39, scope: !5) +!154 = !DILocation(line: 215, column: 31, scope: !5) +!155 = !DILocation(line: 215, column: 45, scope: !5) +!156 = !DILocation(line: 216, column: 62, scope: !5) +!157 = !DILocation(line: 216, column: 43, scope: !5) +!158 = !DILocation(line: 218, column: 33, scope: !5) +!159 = !DILocation(line: 390, column: 37, scope: !65, inlinedAt: !160) +!160 = !DILocation(line: 226, column: 16, scope: !5) +!161 = !DILocation(line: 390, column: 18, scope: !65, inlinedAt: !160) +!162 = !DILocation(line: 390, column: 49, scope: !65, inlinedAt: !160) +!163 = !DILocation(line: 391, column: 18, scope: !65, inlinedAt: !160) +!164 = !DILocation(line: 391, column: 49, scope: !65, inlinedAt: !160) +!165 = !DILocation(line: 395, column: 43, scope: !65, inlinedAt: !160) +!166 = !DILocation(line: 395, column: 63, scope: !65, inlinedAt: !160) +!167 = !DILocation(line: 397, column: 28, scope: !65, inlinedAt: !160) +!168 = !DILocation(line: 795, column: 52, scope: !65, inlinedAt: !160) +!169 = !DILocation(line: 795, column: 23, scope: !65, inlinedAt: !160) +!170 = !DILocation(line: 414, column: 19, scope: !65, inlinedAt: !160) +!171 = !DILocation(line: 415, column: 19, scope: !65, inlinedAt: !160) +!172 = !DILocation(line: 417, column: 19, scope: !65, inlinedAt: !160) +!173 = !DILocation(line: 459, column: 19, scope: !65, inlinedAt: !160) +!174 = !DILocation(line: 461, column: 14, scope: !65, inlinedAt: !160) +!175 = !DILocation(line: 506, column: 27, scope: !65, inlinedAt: !160) +!176 = !DILocation(line: 476, column: 79, scope: !65, inlinedAt: !160) +!177 = !DILocation(line: 507, column: 39, scope: !65, inlinedAt: !160) +!178 = !DILocation(line: 507, column: 21, scope: !65, inlinedAt: !160) +!179 = !DILocation(line: 512, column: 20, scope: !65, inlinedAt: !160) +!180 = !DILocation(line: 513, column: 19, scope: !65, inlinedAt: !160) +!181 = !DILocation(line: 513, column: 14, scope: !65, inlinedAt: !160) +!182 = !DILocation(line: 533, column: 15, scope: !65, inlinedAt: !160) +!183 = !DILocation(line: 520, column: 71, scope: !65, inlinedAt: !160) +!184 = !DILocation(line: 535, column: 21, scope: !65, inlinedAt: !160) +!185 = !DILocation(line: 752, column: 33, scope: !65, inlinedAt: !160) +!186 = !DILocation(line: 753, column: 38, scope: !65, inlinedAt: !160) +!187 = !DILocation(line: 753, column: 24, scope: !65, inlinedAt: !160) +!188 = !DILocation(line: 754, column: 109, scope: !65, inlinedAt: !160) +!189 = !DILocation(line: 754, column: 113, scope: !65, inlinedAt: !160) +!190 = !DILocation(line: 754, column: 55, scope: !65, inlinedAt: !160) +!191 = !DILocation(line: 754, column: 25, scope: !65, inlinedAt: !160) +!192 = !DILocation(line: 755, column: 35, scope: !65, inlinedAt: !160) +!193 = !DILocation(line: 756, column: 34, scope: !65, inlinedAt: !160) +!194 = !DILocation(line: 756, column: 48, scope: !65, inlinedAt: !160) +!195 = !DILocation(line: 756, column: 63, scope: !65, inlinedAt: !160) +!196 = !DILocation(line: 757, column: 29, scope: !65, inlinedAt: !160) +!197 = !DILocation(line: 757, column: 61, scope: !65, inlinedAt: !160) +!198 = !DILocation(line: 757, column: 42, scope: !65, inlinedAt: !160) +!199 = !DILocation(line: 414, column: 28, scope: !65, inlinedAt: !160) +!200 = !DILocation(line: 231, column: 24, scope: !5) +!201 = !DILocation(line: 231, column: 56, scope: !5) +!202 = !DILocation(line: 232, column: 14, scope: !5) +!203 = !DILocation(line: 236, column: 30, scope: !5) +!204 = !DILocation(line: 252, column: 25, scope: !5) +!205 = !DILocation(line: 253, column: 29, scope: !5) +!206 = !DILocation(line: 789, column: 38, scope: !65, inlinedAt: !207) +!207 = !DILocation(line: 256, column: 107, scope: !5) +!208 = !DILocation(line: 789, column: 20, scope: !65, inlinedAt: !207) +!209 = !DILocation(line: 789, column: 56, scope: !65, inlinedAt: !207) +!210 = !DILocation(line: 789, column: 49, scope: !65, inlinedAt: !207) +!211 = !DILocation(line: 797, column: 52, scope: !65, inlinedAt: !207) +!212 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !207) +!213 = !DILocation(line: 789, column: 20, scope: !65, inlinedAt: !214) +!214 = !DILocation(line: 257, column: 107, scope: !5) +!215 = !DILocation(line: 789, column: 49, scope: !65, inlinedAt: !214) +!216 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !214) +!217 = !DILocation(line: 263, column: 32, scope: !5) +!218 = !DILocation(line: 266, column: 56, scope: !5) +!219 = !DILocation(line: 267, column: 59, scope: !5) +!220 = !DILocation(line: 269, column: 34, scope: !5) +!221 = !DILocation(line: 282, column: 81, scope: !5) +!222 = !DILocation(line: 286, column: 32, scope: !5) +!223 = !DILocation(line: 287, column: 30, scope: !5) +!224 = !DILocation(line: 287, column: 43, scope: !5) +!225 = !DILocation(line: 288, column: 55, scope: !5) +!226 = !DILocation(line: 288, column: 42, scope: !5) +!227 = !DILocation(line: 290, column: 45, scope: !5) +!228 = !DILocation(line: 290, column: 32, scope: !5) +!229 = !DILocation(line: 583, column: 37, scope: !65, inlinedAt: !230) +!230 = !DILocation(line: 298, column: 16, scope: !5) +!231 = !DILocation(line: 584, column: 38, scope: !65, inlinedAt: !230) +!232 = !DILocation(line: 590, column: 42, scope: !65, inlinedAt: !230) +!233 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !230) +!234 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !230) +!235 = !DILocation(line: 590, column: 98, scope: !65, inlinedAt: !230) +!236 = !DILocation(line: 590, column: 61, scope: !65, inlinedAt: !230) +!237 = !DILocation(line: 762, column: 21, scope: !65, inlinedAt: !230) +!238 = !DILocation(line: 692, column: 91, scope: !65, inlinedAt: !230) +!239 = !DILocation(line: 306, column: 41, scope: !5) +!240 = !DILocation(line: 307, column: 34, scope: !5) +!241 = !DILocation(line: 307, column: 47, scope: !5) +!242 = !DILocation(line: 308, column: 64, scope: !5) +!243 = !DILocation(line: 308, column: 46, scope: !5) +!244 = !DILocation(line: 310, column: 36, scope: !5) +!245 = !DILocation(line: 583, column: 37, scope: !65, inlinedAt: !246) +!246 = !DILocation(line: 318, column: 20, scope: !5) +!247 = !DILocation(line: 584, column: 38, scope: !65, inlinedAt: !246) +!248 = !DILocation(line: 590, column: 42, scope: !65, inlinedAt: !246) +!249 = !DILocation(line: 590, column: 61, scope: !65, inlinedAt: !246) +!250 = !DILocation(line: 658, column: 20, scope: !65, inlinedAt: !230) +!251 = !DILocation(line: 262, column: 30, scope: !5) +!252 = !DILocation(line: 263, column: 51, scope: !5) +!253 = !DILocation(line: 266, column: 44, scope: !5) +!254 = !DILocation(line: 266, column: 67, scope: !5) +!255 = !DILocation(line: 267, column: 36, scope: !5) +!256 = !DILocation(line: 267, column: 46, scope: !5) +!257 = !DILocation(line: 267, column: 70, scope: !5) +!258 = !DILocation(line: 269, column: 50, scope: !5) +!259 = !DILocation(line: 269, column: 60, scope: !5) +!260 = !DILocation(line: 271, column: 21, scope: !5) +!261 = !DILocation(line: 272, column: 23, scope: !5) +!262 = !DILocation(line: 275, column: 25, scope: !5) +!263 = !DILocation(line: 276, column: 29, scope: !5) +!264 = !DILocation(line: 583, column: 18, scope: !65, inlinedAt: !230) +!265 = !DILocation(line: 583, column: 49, scope: !65, inlinedAt: !230) +!266 = !DILocation(line: 584, column: 19, scope: !65, inlinedAt: !230) +!267 = !DILocation(line: 584, column: 51, scope: !65, inlinedAt: !230) +!268 = !DILocation(line: 795, column: 23, scope: !65, inlinedAt: !230) +!269 = !DILocation(line: 656, column: 28, scope: !65, inlinedAt: !230) +!270 = !DILocation(line: 656, column: 22, scope: !65, inlinedAt: !230) +!271 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !230) +!272 = !DILocation(line: 712, column: 29, scope: !65, inlinedAt: !230) +!273 = !DILocation(line: 712, column: 21, scope: !65, inlinedAt: !230) +!274 = !DILocation(line: 608, column: 19, scope: !65, inlinedAt: !230) +!275 = !DILocation(line: 609, column: 19, scope: !65, inlinedAt: !230) +!276 = !DILocation(line: 592, column: 28, scope: !65, inlinedAt: !230) +!277 = !DILocation(line: 795, column: 52, scope: !65, inlinedAt: !230) +!278 = !DILocation(line: 657, column: 26, scope: !65, inlinedAt: !230) +!279 = !DILocation(line: 657, column: 46, scope: !65, inlinedAt: !230) +!280 = !DILocation(line: 660, column: 15, scope: !65, inlinedAt: !230) +!281 = !DILocation(line: 679, column: 24, scope: !65, inlinedAt: !230) +!282 = !DILocation(line: 683, column: 25, scope: !65, inlinedAt: !230) +!283 = !DILocation(line: 681, column: 25, scope: !65, inlinedAt: !230) +!284 = !DILocation(line: 682, column: 24, scope: !65, inlinedAt: !230) +!285 = !DILocation(line: 690, column: 79, scope: !65, inlinedAt: !230) +!286 = !DILocation(line: 690, column: 91, scope: !65, inlinedAt: !230) +!287 = !DILocation(line: 690, column: 119, scope: !65, inlinedAt: !230) +!288 = !DILocation(line: 693, column: 25, scope: !65, inlinedAt: !230) +!289 = !DILocation(line: 694, column: 24, scope: !65, inlinedAt: !230) +!290 = !DILocation(line: 696, column: 24, scope: !65, inlinedAt: !230) +!291 = !DILocation(line: 700, column: 69, scope: !65, inlinedAt: !230) +!292 = !DILocation(line: 703, column: 27, scope: !65, inlinedAt: !230) +!293 = !DILocation(line: 704, column: 40, scope: !65, inlinedAt: !230) +!294 = !DILocation(line: 704, column: 22, scope: !65, inlinedAt: !230) +!295 = !DILocation(line: 708, column: 24, scope: !65, inlinedAt: !230) +!296 = !DILocation(line: 708, column: 43, scope: !65, inlinedAt: !230) +!297 = !DILocation(line: 714, column: 20, scope: !65, inlinedAt: !230) +!298 = !DILocation(line: 715, column: 22, scope: !65, inlinedAt: !230) +!299 = !DILocation(line: 715, column: 16, scope: !65, inlinedAt: !230) +!300 = !DILocation(line: 739, column: 24, scope: !65, inlinedAt: !230) +!301 = !DILocation(line: 737, column: 45, scope: !65, inlinedAt: !230) +!302 = !DILocation(line: 739, column: 43, scope: !65, inlinedAt: !230) +!303 = !DILocation(line: 610, column: 19, scope: !65, inlinedAt: !230) +!304 = !DILocation(line: 752, column: 33, scope: !65, inlinedAt: !230) +!305 = !DILocation(line: 753, column: 38, scope: !65, inlinedAt: !230) +!306 = !DILocation(line: 753, column: 24, scope: !65, inlinedAt: !230) +!307 = !DILocation(line: 754, column: 109, scope: !65, inlinedAt: !230) +!308 = !DILocation(line: 754, column: 113, scope: !65, inlinedAt: !230) +!309 = !DILocation(line: 754, column: 55, scope: !65, inlinedAt: !230) +!310 = !DILocation(line: 754, column: 25, scope: !65, inlinedAt: !230) +!311 = !DILocation(line: 755, column: 35, scope: !65, inlinedAt: !230) +!312 = !DILocation(line: 756, column: 34, scope: !65, inlinedAt: !230) +!313 = !DILocation(line: 756, column: 48, scope: !65, inlinedAt: !230) +!314 = !DILocation(line: 756, column: 63, scope: !65, inlinedAt: !230) +!315 = !DILocation(line: 757, column: 29, scope: !65, inlinedAt: !230) +!316 = !DILocation(line: 757, column: 61, scope: !65, inlinedAt: !230) +!317 = !DILocation(line: 757, column: 42, scope: !65, inlinedAt: !230) +!318 = !DILocation(line: 608, column: 28, scope: !65, inlinedAt: !230) +!319 = !DILocation(line: 609, column: 28, scope: !65, inlinedAt: !230) +!320 = !DILocation(line: 656, column: 52, scope: !65, inlinedAt: !230) +!321 = !DILocation(line: 797, column: 52, scope: !65, inlinedAt: !230) +!322 = !DILocation(line: 583, column: 18, scope: !65, inlinedAt: !246) +!323 = !DILocation(line: 583, column: 49, scope: !65, inlinedAt: !246) +!324 = !DILocation(line: 584, column: 19, scope: !65, inlinedAt: !246) +!325 = !DILocation(line: 584, column: 51, scope: !65, inlinedAt: !246) +!326 = !DILocation(line: 795, column: 23, scope: !65, inlinedAt: !246) +!327 = !DILocation(line: 656, column: 28, scope: !65, inlinedAt: !246) +!328 = !DILocation(line: 656, column: 22, scope: !65, inlinedAt: !246) +!329 = !DILocation(line: 797, column: 23, scope: !65, inlinedAt: !246) +!330 = !DILocation(line: 712, column: 29, scope: !65, inlinedAt: !246) +!331 = !DILocation(line: 712, column: 21, scope: !65, inlinedAt: !246) +!332 = !DILocation(line: 608, column: 19, scope: !65, inlinedAt: !246) +!333 = !DILocation(line: 609, column: 19, scope: !65, inlinedAt: !246) +!334 = !DILocation(line: 592, column: 28, scope: !65, inlinedAt: !246) +!335 = !DILocation(line: 795, column: 52, scope: !65, inlinedAt: !246) +!336 = !DILocation(line: 657, column: 26, scope: !65, inlinedAt: !246) +!337 = !DILocation(line: 657, column: 46, scope: !65, inlinedAt: !246) +!338 = !DILocation(line: 658, column: 20, scope: !65, inlinedAt: !246) +!339 = !DILocation(line: 660, column: 15, scope: !65, inlinedAt: !246) +!340 = !DILocation(line: 703, column: 27, scope: !65, inlinedAt: !246) +!341 = !DILocation(line: 674, column: 78, scope: !65, inlinedAt: !246) +!342 = !DILocation(line: 704, column: 40, scope: !65, inlinedAt: !246) +!343 = !DILocation(line: 704, column: 22, scope: !65, inlinedAt: !246) +!344 = !DILocation(line: 708, column: 24, scope: !65, inlinedAt: !246) +!345 = !DILocation(line: 708, column: 43, scope: !65, inlinedAt: !246) +!346 = !DILocation(line: 714, column: 20, scope: !65, inlinedAt: !246) +!347 = !DILocation(line: 715, column: 22, scope: !65, inlinedAt: !246) +!348 = !DILocation(line: 715, column: 16, scope: !65, inlinedAt: !246) +!349 = !DILocation(line: 739, column: 24, scope: !65, inlinedAt: !246) +!350 = !DILocation(line: 723, column: 70, scope: !65, inlinedAt: !246) +!351 = !DILocation(line: 739, column: 43, scope: !65, inlinedAt: !246) +!352 = !DILocation(line: 752, column: 33, scope: !65, inlinedAt: !246) +!353 = !DILocation(line: 753, column: 38, scope: !65, inlinedAt: !246) +!354 = !DILocation(line: 753, column: 24, scope: !65, inlinedAt: !246) +!355 = !DILocation(line: 754, column: 109, scope: !65, inlinedAt: !246) +!356 = !DILocation(line: 754, column: 113, scope: !65, inlinedAt: !246) +!357 = !DILocation(line: 754, column: 55, scope: !65, inlinedAt: !246) +!358 = !DILocation(line: 754, column: 25, scope: !65, inlinedAt: !246) +!359 = !DILocation(line: 755, column: 35, scope: !65, inlinedAt: !246) +!360 = !DILocation(line: 756, column: 34, scope: !65, inlinedAt: !246) +!361 = !DILocation(line: 756, column: 48, scope: !65, inlinedAt: !246) +!362 = !DILocation(line: 756, column: 63, scope: !65, inlinedAt: !246) +!363 = !DILocation(line: 757, column: 29, scope: !65, inlinedAt: !246) +!364 = !DILocation(line: 757, column: 61, scope: !65, inlinedAt: !246) +!365 = !DILocation(line: 757, column: 42, scope: !65, inlinedAt: !246) +!366 = !DILocation(line: 608, column: 28, scope: !65, inlinedAt: !246) +!367 = !DILocation(line: 609, column: 28, scope: !65, inlinedAt: !246) +!368 = !DILocation(line: 610, column: 19, scope: !65, inlinedAt: !246) +!369 = !DILocation(line: 656, column: 52, scope: !65, inlinedAt: !246) +!370 = !DILocation(line: 797, column: 52, scope: !65, inlinedAt: !246) +!371 = !DILocation(line: 323, column: 23, scope: !5) +!372 = !DILocation(line: 323, column: 55, scope: !5) +!373 = !DILocation(line: 332, column: 30, scope: !5) +!374 = !DILocation(line: 334, column: 14, scope: !5) +!375 = !DILocation(line: 345, column: 55, scope: !5) +!376 = !DILocation(line: 345, column: 69, scope: !5) +!377 = !DILocation(line: 345, column: 29, scope: !5) +!378 = !DILocation(line: 345, column: 99, scope: !5) +!379 = !DILocation(line: 139, column: 4, scope: !5) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.source b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.source new file mode 100644 index 0000000000000000000000000000000000000000..667929325ef9f5065e3da65f13658006384c874f --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.source @@ -0,0 +1,2369 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":18:0) +#loc228 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":32:0) +#loc238 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":776:0) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":348:0) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":423:0) +#loc357 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":761:0) +#loc361 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":745:0) +#loc382 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":541:0) +#loc412 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":616:0) +#loc493 = loc("arg_Q"(#loc)) +#loc494 = loc("arg_K"(#loc)) +#loc495 = loc("arg_V"(#loc)) +#loc496 = loc("arg_LSE"(#loc)) +#loc497 = loc("arg_DELTA"(#loc)) +#loc498 = loc("arg_DO"(#loc)) +#loc499 = loc("arg_DQ"(#loc)) +#loc500 = loc("arg_DV"(#loc)) +#loc501 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc502 = loc("arg_KV_IDX"(#loc)) +#loc503 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc504 = loc("arg_Q_IDX"(#loc)) +#loc505 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc506 = loc("arg_FULL_KV_IDX"(#loc)) +#loc507 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc508 = loc("arg_FULL_Q_IDX"(#loc)) +#loc509 = loc("out_ptr0"(#loc)) +#loc510 = loc("ks0"(#loc)) +#loc511 = loc("ks1"(#loc)) +#loc512 = loc("ks2"(#loc)) +#loc513 = loc("ks3"(#loc)) +#loc514 = loc("ks4"(#loc)) +#loc515 = loc("ks5"(#loc)) +#loc516 = loc("ks6"(#loc)) +#loc517 = loc("ks7"(#loc)) +#loc700 = loc("x"(#loc228)) +#loc701 = loc("ptr"(#loc238)) +#loc702 = loc("offs_m"(#loc238)) +#loc703 = loc("offs_n"(#loc238)) +#loc704 = loc("stride_m"(#loc238)) +#loc705 = loc("stride_n"(#loc238)) +#loc706 = loc("M_LEN"(#loc238)) +#loc713 = loc("arg_Q"(#loc250)) +#loc714 = loc("arg_K"(#loc250)) +#loc715 = loc("arg_V"(#loc250)) +#loc716 = loc("arg_LSE"(#loc250)) +#loc717 = loc("arg_DELTA"(#loc250)) +#loc718 = loc("arg_DO"(#loc250)) +#loc719 = loc("arg_DQ"(#loc250)) +#loc720 = loc("arg_DV"(#loc250)) +#loc721 = loc("arg_KV_NUM_BLKS"(#loc250)) +#loc722 = loc("arg_KV_IDX"(#loc250)) +#loc723 = loc("arg_Q_NUM_BLKS"(#loc250)) +#loc724 = loc("arg_Q_IDX"(#loc250)) +#loc725 = loc("arg_FULL_KV_NUM_BLKS"(#loc250)) +#loc726 = loc("arg_FULL_KV_IDX"(#loc250)) +#loc727 = loc("arg_FULL_Q_NUM_BLKS"(#loc250)) +#loc728 = loc("arg_FULL_Q_IDX"(#loc250)) +#loc729 = loc("out_ptr0"(#loc250)) +#loc730 = loc("ks0"(#loc250)) +#loc731 = loc("ks1"(#loc250)) +#loc732 = loc("ks2"(#loc250)) +#loc733 = loc("ks3"(#loc250)) +#loc734 = loc("ks4"(#loc250)) +#loc735 = loc("ks5"(#loc250)) +#loc736 = loc("ks6"(#loc250)) +#loc737 = loc("ks7"(#loc250)) +#loc738 = loc("K"(#loc250)) +#loc739 = loc("V"(#loc250)) +#loc740 = loc("dq"(#loc250)) +#loc741 = loc("q"(#loc250)) +#loc742 = loc("do"(#loc250)) +#loc743 = loc("Di"(#loc250)) +#loc744 = loc("lse"(#loc250)) +#loc745 = loc("off_z"(#loc250)) +#loc746 = loc("off_hq"(#loc250)) +#loc747 = loc("offs_m2"(#loc250)) +#loc748 = loc("offs_n2"(#loc250)) +#loc749 = loc("stride_kn"(#loc250)) +#loc750 = loc("stride_kd"(#loc250)) +#loc751 = loc("stride_vn"(#loc250)) +#loc752 = loc("stride_vd"(#loc250)) +#loc753 = loc("kv_indices"(#loc250)) +#loc754 = loc("sparse_kv_num_blocks"(#loc250)) +#loc781 = loc("arg_Q"(#loc280)) +#loc782 = loc("arg_K"(#loc280)) +#loc783 = loc("arg_V"(#loc280)) +#loc784 = loc("arg_LSE"(#loc280)) +#loc785 = loc("arg_DELTA"(#loc280)) +#loc786 = loc("arg_DO"(#loc280)) +#loc787 = loc("arg_DQ"(#loc280)) +#loc788 = loc("arg_DV"(#loc280)) +#loc789 = loc("arg_KV_NUM_BLKS"(#loc280)) +#loc790 = loc("arg_KV_IDX"(#loc280)) +#loc791 = loc("arg_Q_NUM_BLKS"(#loc280)) +#loc792 = loc("arg_Q_IDX"(#loc280)) +#loc793 = loc("arg_FULL_KV_NUM_BLKS"(#loc280)) +#loc794 = loc("arg_FULL_KV_IDX"(#loc280)) +#loc795 = loc("arg_FULL_Q_NUM_BLKS"(#loc280)) +#loc796 = loc("arg_FULL_Q_IDX"(#loc280)) +#loc797 = loc("out_ptr0"(#loc280)) +#loc798 = loc("ks0"(#loc280)) +#loc799 = loc("ks1"(#loc280)) +#loc800 = loc("ks2"(#loc280)) +#loc801 = loc("ks3"(#loc280)) +#loc802 = loc("ks4"(#loc280)) +#loc803 = loc("ks5"(#loc280)) +#loc804 = loc("ks6"(#loc280)) +#loc805 = loc("ks7"(#loc280)) +#loc806 = loc("dq"(#loc280)) +#loc807 = loc("q"(#loc280)) +#loc808 = loc("kT_ptrs"(#loc280)) +#loc809 = loc("vT_ptrs"(#loc280)) +#loc810 = loc("do"(#loc280)) +#loc811 = loc("Di"(#loc280)) +#loc812 = loc("lse"(#loc280)) +#loc813 = loc("Q_LEN"(#loc280)) +#loc814 = loc("KV_LEN"(#loc280)) +#loc815 = loc("off_z"(#loc280)) +#loc816 = loc("off_hq"(#loc280)) +#loc817 = loc("offs_m2"(#loc280)) +#loc818 = loc("offs_n2"(#loc280)) +#loc819 = loc("offs_k"(#loc280)) +#loc820 = loc("offs_v"(#loc280)) +#loc821 = loc("stride_kn"(#loc280)) +#loc822 = loc("stride_kd"(#loc280)) +#loc823 = loc("stride_vn"(#loc280)) +#loc824 = loc("stride_vd"(#loc280)) +#loc825 = loc("kv_indices"(#loc280)) +#loc826 = loc("sparse_kv_num_blocks"(#loc280)) +#loc897 = loc("N_LEN"(#loc238)) +#loc898 = loc("indices"(#loc357)) +#loc899 = loc("max_len"(#loc357)) +#loc900 = loc("loop_iter"(#loc361)) +#loc901 = loc("col_indices"(#loc361)) +#loc902 = loc("total_blocks"(#loc361)) +#loc921 = loc("arg_Q"(#loc382)) +#loc922 = loc("arg_K"(#loc382)) +#loc923 = loc("arg_V"(#loc382)) +#loc924 = loc("arg_LSE"(#loc382)) +#loc925 = loc("arg_DELTA"(#loc382)) +#loc926 = loc("arg_DO"(#loc382)) +#loc927 = loc("arg_DQ"(#loc382)) +#loc928 = loc("arg_DV"(#loc382)) +#loc929 = loc("arg_KV_NUM_BLKS"(#loc382)) +#loc930 = loc("arg_KV_IDX"(#loc382)) +#loc931 = loc("arg_Q_NUM_BLKS"(#loc382)) +#loc932 = loc("arg_Q_IDX"(#loc382)) +#loc933 = loc("arg_FULL_KV_NUM_BLKS"(#loc382)) +#loc934 = loc("arg_FULL_KV_IDX"(#loc382)) +#loc935 = loc("arg_FULL_Q_NUM_BLKS"(#loc382)) +#loc936 = loc("arg_FULL_Q_IDX"(#loc382)) +#loc937 = loc("out_ptr0"(#loc382)) +#loc938 = loc("ks0"(#loc382)) +#loc939 = loc("ks1"(#loc382)) +#loc940 = loc("ks2"(#loc382)) +#loc941 = loc("ks3"(#loc382)) +#loc942 = loc("ks4"(#loc382)) +#loc943 = loc("ks5"(#loc382)) +#loc944 = loc("ks6"(#loc382)) +#loc945 = loc("ks7"(#loc382)) +#loc946 = loc("Q"(#loc382)) +#loc947 = loc("DO"(#loc382)) +#loc948 = loc("DELTA"(#loc382)) +#loc949 = loc("LSE"(#loc382)) +#loc950 = loc("dk"(#loc382)) +#loc951 = loc("dv"(#loc382)) +#loc952 = loc("k"(#loc382)) +#loc953 = loc("v"(#loc382)) +#loc954 = loc("off_z"(#loc382)) +#loc955 = loc("off_hq"(#loc382)) +#loc956 = loc("offs_n1"(#loc382)) +#loc957 = loc("offs_m1"(#loc382)) +#loc958 = loc("stride_qm"(#loc382)) +#loc959 = loc("stride_qd"(#loc382)) +#loc960 = loc("stride_dom"(#loc382)) +#loc961 = loc("stride_dod"(#loc382)) +#loc962 = loc("q_indices"(#loc382)) +#loc963 = loc("sparse_q_num_blocks"(#loc382)) +#loc989 = loc("arg_Q"(#loc412)) +#loc990 = loc("arg_K"(#loc412)) +#loc991 = loc("arg_V"(#loc412)) +#loc992 = loc("arg_LSE"(#loc412)) +#loc993 = loc("arg_DELTA"(#loc412)) +#loc994 = loc("arg_DO"(#loc412)) +#loc995 = loc("arg_DQ"(#loc412)) +#loc996 = loc("arg_DV"(#loc412)) +#loc997 = loc("arg_KV_NUM_BLKS"(#loc412)) +#loc998 = loc("arg_KV_IDX"(#loc412)) +#loc999 = loc("arg_Q_NUM_BLKS"(#loc412)) +#loc1000 = loc("arg_Q_IDX"(#loc412)) +#loc1001 = loc("arg_FULL_KV_NUM_BLKS"(#loc412)) +#loc1002 = loc("arg_FULL_KV_IDX"(#loc412)) +#loc1003 = loc("arg_FULL_Q_NUM_BLKS"(#loc412)) +#loc1004 = loc("arg_FULL_Q_IDX"(#loc412)) +#loc1005 = loc("out_ptr0"(#loc412)) +#loc1006 = loc("ks0"(#loc412)) +#loc1007 = loc("ks1"(#loc412)) +#loc1008 = loc("ks2"(#loc412)) +#loc1009 = loc("ks3"(#loc412)) +#loc1010 = loc("ks4"(#loc412)) +#loc1011 = loc("ks5"(#loc412)) +#loc1012 = loc("ks6"(#loc412)) +#loc1013 = loc("ks7"(#loc412)) +#loc1014 = loc("dk"(#loc412)) +#loc1015 = loc("dv"(#loc412)) +#loc1016 = loc("qT_ptrs"(#loc412)) +#loc1017 = loc("k"(#loc412)) +#loc1018 = loc("v"(#loc412)) +#loc1019 = loc("do_ptrs"(#loc412)) +#loc1020 = loc("DELTA"(#loc412)) +#loc1021 = loc("LSE"(#loc412)) +#loc1022 = loc("Q_LEN"(#loc412)) +#loc1023 = loc("KV_LEN"(#loc412)) +#loc1024 = loc("off_z"(#loc412)) +#loc1025 = loc("off_hq"(#loc412)) +#loc1026 = loc("offs_n1"(#loc412)) +#loc1027 = loc("offs_m1"(#loc412)) +#loc1028 = loc("offs_k"(#loc412)) +#loc1029 = loc("offs_v"(#loc412)) +#loc1030 = loc("stride_qm"(#loc412)) +#loc1031 = loc("stride_qd"(#loc412)) +#loc1032 = loc("stride_dom"(#loc412)) +#loc1033 = loc("stride_dod"(#loc412)) +#loc1034 = loc("q_indices"(#loc412)) +#loc1035 = loc("sparse_q_num_blocks"(#loc412)) +module { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc)), %ks2: i32 loc("ks2"(#loc)), %ks3: i32 loc("ks3"(#loc)), %ks4: i32 loc("ks4"(#loc)), %ks5: i32 loc("ks5"(#loc)), %ks6: i32 loc("ks6"(#loc)), %ks7: i32 loc("ks7"(#loc))) attributes {noinline = false} { + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c4096_i32_0 = arith.constant 4096 : i32 loc(#loc1) + %0 = arith.muli %c4096_i32_0, %ks0 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc2) + %c4096_i32_1 = arith.constant 4096 : i32 loc(#loc2) + %c1_i32 = arith.constant 1 : i32 loc(#loc2) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc3) + %c1024_i32_2 = arith.constant 1024 : i32 loc(#loc3) + %1 = arith.muli %c1024_i32_2, %ks1 : i32 loc(#loc3) + %c128_i32_3 = arith.constant 128 : i32 loc(#loc4) + %c1024_i32_4 = arith.constant 1024 : i32 loc(#loc4) + %c1_i32_5 = arith.constant 1 : i32 loc(#loc4) + %c1024_i32_6 = arith.constant 1024 : i32 loc(#loc5) + %c1024_i32_7 = arith.constant 1024 : i32 loc(#loc5) + %2 = arith.muli %c1024_i32_7, %ks1 : i32 loc(#loc5) + %c128_i32_8 = arith.constant 128 : i32 loc(#loc6) + %c1024_i32_9 = arith.constant 1024 : i32 loc(#loc6) + %c1_i32_10 = arith.constant 1 : i32 loc(#loc6) + %c1_i32_11 = arith.constant 1 : i32 loc(#loc7) + %3 = arith.cmpi sge, %c1_i32_11, %ks0 : i32 loc(#loc7) + %c1_i32_12 = arith.constant 1 : i32 loc(#loc8) + %c1_i32_13 = arith.constant 1 : i32 loc(#loc8) + %4 = arith.extui %3 : i1 to i32 loc(#loc8) + %5 = arith.muli %c1_i32_13, %4 : i32 loc(#loc8) + %c1_i32_14 = arith.constant 1 : i32 loc(#loc9) + %6 = arith.cmpi sgt, %ks0, %c1_i32_14 : i32 loc(#loc9) + %7 = arith.extui %6 : i1 to i32 loc(#loc10) + %8 = arith.muli %ks0, %7 : i32 loc(#loc10) + %9 = arith.addi %5, %8 : i32 loc(#loc11) + %c4096_i32_15 = arith.constant 4096 : i32 loc(#loc12) + %c4096_i32_16 = arith.constant 4096 : i32 loc(#loc12) + %10 = arith.muli %c4096_i32_16, %9 : i32 loc(#loc12) + %c1_i32_17 = arith.constant 1 : i32 loc(#loc13) + %11 = arith.cmpi sge, %c1_i32_17, %ks0 : i32 loc(#loc13) + %c1_i32_18 = arith.constant 1 : i32 loc(#loc14) + %c1_i32_19 = arith.constant 1 : i32 loc(#loc14) + %12 = arith.extui %11 : i1 to i32 loc(#loc14) + %13 = arith.muli %c1_i32_19, %12 : i32 loc(#loc14) + %c1_i32_20 = arith.constant 1 : i32 loc(#loc15) + %14 = arith.cmpi sgt, %ks0, %c1_i32_20 : i32 loc(#loc15) + %15 = arith.extui %14 : i1 to i32 loc(#loc16) + %16 = arith.muli %ks0, %15 : i32 loc(#loc16) + %17 = arith.addi %13, %16 : i32 loc(#loc17) + %c128_i32_21 = arith.constant 128 : i32 loc(#loc18) + %c128_i32_22 = arith.constant 128 : i32 loc(#loc18) + %18 = arith.muli %c128_i32_22, %17 : i32 loc(#loc18) + %c128_i32_23 = arith.constant 128 : i32 loc(#loc19) + %c1_i32_24 = arith.constant 1 : i32 loc(#loc19) + %c4096_i32_25 = arith.constant 4096 : i32 loc(#loc20) + %c4096_i32_26 = arith.constant 4096 : i32 loc(#loc20) + %19 = arith.muli %c4096_i32_26, %ks0 : i32 loc(#loc20) + %c128_i32_27 = arith.constant 128 : i32 loc(#loc21) + %c4096_i32_28 = arith.constant 4096 : i32 loc(#loc21) + %c1_i32_29 = arith.constant 1 : i32 loc(#loc21) + %c1024_i32_30 = arith.constant 1024 : i32 loc(#loc22) + %c1024_i32_31 = arith.constant 1024 : i32 loc(#loc22) + %20 = arith.muli %c1024_i32_31, %ks1 : i32 loc(#loc22) + %c128_i32_32 = arith.constant 128 : i32 loc(#loc23) + %c1024_i32_33 = arith.constant 1024 : i32 loc(#loc23) + %c1_i32_34 = arith.constant 1 : i32 loc(#loc23) + %ZQ = arith.constant 1 : i32 loc(#loc518) + %HQ = arith.constant 32 : i32 loc(#loc519) + %HKV = arith.constant 8 : i32 loc(#loc520) + %ZKV = arith.constant 1 : i32 loc(#loc521) + %pid = tt.get_program_id x : i32 loc(#loc522) + %NUM_KV_BLOCKS = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%ks1) : (i32) -> i32 loc(#loc523) + %NUM_Q_BLOCKS = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%ks0) : (i32) -> i32 loc(#loc524) + %off_zq = tt.get_program_id y : i32 loc(#loc525) + %off_hkv = tt.get_program_id z : i32 loc(#loc526) + %off_zkv = arith.remsi %off_zq, %ZKV : i32 loc(#loc527) + %SPARSE_Z = arith.constant 1 : i32 loc(#loc528) + %SPARSE_HQ = arith.constant 1 : i32 loc(#loc529) + %sparse_idx_z = arith.remsi %off_zq, %SPARSE_Z : i32 loc(#loc530) + %k_adj = arith.muli %c128_i32_3, %off_hkv : i32 loc(#loc531) + %k_adj_35 = arith.muli %1, %off_zkv : i32 loc(#loc532) + %k_adj_36 = arith.addi %k_adj, %k_adj_35 : i32 loc(#loc533) + %k_adj_37 = arith.extsi %k_adj_36 : i32 to i64 loc(#loc534) + %v_adj = arith.muli %c128_i32_8, %off_hkv : i32 loc(#loc535) + %v_adj_38 = arith.muli %2, %off_zkv : i32 loc(#loc536) + %v_adj_39 = arith.addi %v_adj, %v_adj_38 : i32 loc(#loc537) + %v_adj_40 = arith.extsi %v_adj_39 : i32 to i64 loc(#loc538) + %dv_adj = arith.muli %c128_i32_32, %off_hkv : i32 loc(#loc539) + %dv_adj_41 = arith.muli %20, %off_zq : i32 loc(#loc540) + %dv_adj_42 = arith.addi %dv_adj, %dv_adj_41 : i32 loc(#loc541) + %dv_adj_43 = arith.extsi %dv_adj_42 : i32 to i64 loc(#loc542) + %K = tt.addptr %arg_K, %k_adj_37 : !tt.ptr, i64 loc(#loc543) + %V = tt.addptr %arg_V, %v_adj_40 : !tt.ptr, i64 loc(#loc544) + %DV = tt.addptr %arg_DV, %dv_adj_43 : !tt.ptr, i64 loc(#loc545) + %RCP_LN2 = arith.constant 1.44269502 : f32 loc(#loc546) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc547) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc548) + %21 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS : i32 loc(#loc55) + %22:2 = scf.if %21 -> (i32, i32) { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS : i32 loc(#loc549) + %SPARSE_Q_MULTIPLE = arith.constant 1 : i32 loc(#loc1114) + %SPARSE_KV_MULTIPLE = arith.constant 2 : i32 loc(#loc1115) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc552) + %off_hq2_44 = arith.constant 4 : i32 loc(#loc553) + %off_hq2_45 = arith.constant 4 : i32 loc(#loc553) + %off_hq2_46 = arith.muli %off_hkv, %off_hq2_45 : i32 loc(#loc553) + %off_hq2_47 = arith.addi %off_hq2, %off_hq2_46 : i32 loc(#loc554) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc555) + %off_pid_mask = arith.divsi %start_m2_block, %SPARSE_Q_MULTIPLE : i32 loc(#loc556) + %stride_kv_idx_h = arith.muli %ks3, %ks4 : i32 loc(#loc557) + %sparse_idx_hq2 = arith.remsi %off_hq2_47, %SPARSE_HQ : i32 loc(#loc558) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc559) + %sparse_hz_offset_48 = arith.addi %sparse_hz_offset, %sparse_idx_hq2 : i32 loc(#loc560) + %sparse_kv_num_blks_offset = arith.muli %sparse_hz_offset_48, %ks2 : i32 loc(#loc561) + %sparse_kv_num_blks_offset_49 = arith.addi %sparse_kv_num_blks_offset, %off_pid_mask : i32 loc(#loc562) + %sparse_kv_idx_offset = arith.muli %sparse_hz_offset_48, %stride_kv_idx_h : i32 loc(#loc563) + %sparse_kv_idx_offset_50 = arith.muli %off_pid_mask, %ks4 : i32 loc(#loc564) + %sparse_kv_idx_offset_51 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_50 : i32 loc(#loc565) + %q_adj2 = arith.muli %c128_i32, %off_hq2_47 : i32 loc(#loc566) + %q_adj2_52 = arith.muli %0, %off_zq : i32 loc(#loc567) + %q_adj2_53 = arith.addi %q_adj2, %q_adj2_52 : i32 loc(#loc568) + %q_adj2_54 = arith.extsi %q_adj2_53 : i32 to i64 loc(#loc569) + %do_adj2 = arith.muli %18, %off_hq2_47 : i32 loc(#loc570) + %do_adj2_55 = arith.muli %10, %off_zq : i32 loc(#loc571) + %do_adj2_56 = arith.addi %do_adj2, %do_adj2_55 : i32 loc(#loc572) + %do_adj2_57 = arith.extsi %do_adj2_56 : i32 to i64 loc(#loc573) + %dq_adj2 = arith.muli %c128_i32_27, %off_hq2_47 : i32 loc(#loc574) + %dq_adj2_58 = arith.muli %19, %off_zq : i32 loc(#loc575) + %dq_adj2_59 = arith.addi %dq_adj2, %dq_adj2_58 : i32 loc(#loc576) + %dq_adj2_60 = arith.extsi %dq_adj2_59 : i32 to i64 loc(#loc577) + %off_chz2 = arith.muli %off_zq, %HQ : i32 loc(#loc578) + %off_chz2_61 = arith.addi %off_chz2, %off_hq2_47 : i32 loc(#loc579) + %off_chz2_62 = arith.muli %off_chz2_61, %ks0 : i32 loc(#loc580) + %off_chz2_63 = arith.extsi %off_chz2_62 : i32 to i64 loc(#loc581) + %Q2 = tt.addptr %arg_Q, %q_adj2_54 : !tt.ptr, i64 loc(#loc582) + %DO2 = tt.addptr %arg_DO, %do_adj2_57 : !tt.ptr, i64 loc(#loc583) + %DQ2 = tt.addptr %arg_DQ, %dq_adj2_60 : !tt.ptr, i64 loc(#loc584) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_63 : !tt.ptr, i64 loc(#loc585) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_63 : !tt.ptr, i64 loc(#loc586) + %dq = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc587) + %start_m2 = arith.constant 128 : i32 loc(#loc588) + %start_m2_64 = arith.constant 128 : i32 loc(#loc588) + %start_m2_65 = arith.muli %start_m2_block, %start_m2_64 : i32 loc(#loc588) + %offs_m2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc589) + %offs_m2_66 = tt.splat %start_m2_65 : i32 -> tensor<128xi32> loc(#loc590) + %offs_m2_67 = arith.addi %offs_m2_66, %offs_m2 : tensor<128xi32> loc(#loc590) + %q = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%Q2, %offs_m2_67, %offs_k, %c4096_i32_1, %c1_i32, %ks0) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc591) + %do = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%DO2, %offs_m2_67, %offs_v, %c128_i32_23, %c1_i32_24, %ks0) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc592) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc593) + %Di_68 = arith.cmpi slt, %offs_m2_67, %Di : tensor<128xi32> loc(#loc593) + %Di_69 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc594) + %Di_70 = tt.addptr %Di_69, %offs_m2_67 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc594) + %Di_71 = tt.load %Di_70, %Di_68 : tensor<128x!tt.ptr> loc(#loc595) + %lse = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc596) + %lse_72 = arith.cmpi slt, %offs_m2_67, %lse : tensor<128xi32> loc(#loc596) + %lse_73 = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc597) + %lse_74 = tt.addptr %lse_73, %offs_m2_67 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc597) + %lse_75 = tt.load %lse_74, %lse_72 : tensor<128x!tt.ptr> loc(#loc598) + %lse_76 = arith.constant 0xFF800000 : f32 loc(#loc599) + %lse_77 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc599) + %lse_78 = arith.cmpf oeq, %lse_75, %lse_77 : tensor<128xf32> loc(#loc599) + %lse_79 = arith.constant 0.000000e+00 : f32 loc(#loc600) + %lse_80 = arith.constant 0.000000e+00 : f32 loc(#loc600) + %lse_81 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc600) + %lse_82 = arith.select %lse_78, %lse_81, %lse_75 : tensor<128xi1>, tensor<128xf32> loc(#loc600) + %lse_83 = tt.expand_dims %lse_82 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc601) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_51 : !tt.ptr, i32 loc(#loc602) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc603) + %kv_start_84 = arith.constant 128 : i32 loc(#loc604) + %kv_start_85 = arith.constant 128 : i32 loc(#loc604) + %kv_start_86 = arith.muli %kv_start, %kv_start_85 : i32 loc(#loc604) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_49 : !tt.ptr, i32 loc(#loc605) + %sparse_kv_num_blocks_87 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc606) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc607) + %offs_n2_88 = tt.splat %kv_start_86 : i32 -> tensor<64xi32> loc(#loc608) + %offs_n2_89 = arith.addi %offs_n2_88, %offs_n2 : tensor<64xi32> loc(#loc608) + %dq_90 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(42,)cconstexpr_bf16__(43,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %K, %V, %dq, %q, %do, %Di_71, %lse_83, %off_zq, %off_hq2_47, %offs_m2_67, %offs_n2_89, %c1024_i32_4, %c1_i32_5, %c1024_i32_9, %c1_i32_10, %kv_indices, %sparse_kv_num_blocks_87) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc609) + %kv_indices_91 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_51 : !tt.ptr, i32 loc(#loc610) + %kv_start_92 = tt.load %kv_indices_91 : !tt.ptr loc(#loc611) + %kv_start_93 = arith.constant 128 : i32 loc(#loc612) + %kv_start_94 = arith.constant 128 : i32 loc(#loc612) + %kv_start_95 = arith.muli %kv_start_92, %kv_start_94 : i32 loc(#loc612) + %sparse_kv_num_blocks_96 = tt.addptr %arg_FULL_KV_NUM_BLKS, %sparse_kv_num_blks_offset_49 : !tt.ptr, i32 loc(#loc613) + %sparse_kv_num_blocks_97 = tt.load %sparse_kv_num_blocks_96 : !tt.ptr loc(#loc614) + %offs_n2_98 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc615) + %offs_n2_99 = tt.splat %kv_start_95 : i32 -> tensor<64xi32> loc(#loc616) + %offs_n2_100 = arith.addi %offs_n2_99, %offs_n2_98 : tensor<64xi32> loc(#loc616) + %dq_101 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(42,)cconstexpr_bf16__(43,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %K, %V, %dq_90, %q, %do, %Di_71, %lse_83, %off_zq, %off_hq2_47, %offs_m2_67, %offs_n2_100, %c1024_i32_4, %c1_i32_5, %c1024_i32_9, %c1_i32_10, %kv_indices_91, %sparse_kv_num_blocks_97) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc617) + %dq_ptrs = tt.expand_dims %offs_m2_67 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc618) + %dq_ptrs_102 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc619) + %dq_ptrs_103 = arith.muli %dq_ptrs, %dq_ptrs_102 : tensor<128x1xi32> loc(#loc619) + %dq_ptrs_104 = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc620) + %dq_ptrs_105 = tt.addptr %dq_ptrs_104, %dq_ptrs_103 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc620) + %dq_ptrs_106 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc621) + %dq_ptrs_107 = arith.constant dense<1> : tensor<1x128xi32> loc(#loc622) + %dq_ptrs_108 = arith.muli %dq_ptrs_106, %dq_ptrs_107 : tensor<1x128xi32> loc(#loc622) + %dq_ptrs_109 = tt.broadcast %dq_ptrs_105 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc623) + %dq_ptrs_110 = tt.broadcast %dq_ptrs_108 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc623) + %dq_ptrs_111 = tt.addptr %dq_ptrs_109, %dq_ptrs_110 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc623) + %dq_112 = arith.constant 0.0883883461 : f32 loc(#loc624) + %dq_113 = arith.constant 0.0883883461 : f32 loc(#loc624) + %dq_114 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc624) + %dq_115 = arith.mulf %dq_101, %dq_114 : tensor<128x128xf32> loc(#loc624) + %23 = tt.expand_dims %offs_m2_67 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc133) + %24 = tt.splat %ks0 : i32 -> tensor<128x1xi32> loc(#loc134) + %25 = arith.cmpi slt, %23, %24 : tensor<128x1xi32> loc(#loc134) + %26 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc135) + %c128_i32_116 = arith.constant 128 : i32 loc(#loc136) + %cst = arith.constant dense<128> : tensor<1x128xi32> loc(#loc136) + %27 = arith.cmpi slt, %26, %cst : tensor<1x128xi32> loc(#loc136) + %28 = tt.broadcast %25 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc137) + %29 = tt.broadcast %27 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc137) + %30 = arith.andi %28, %29 : tensor<128x128xi1> loc(#loc137) + %31 = arith.truncf %dq_115 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc138) + tt.store %dq_ptrs_111, %31, %30 : tensor<128x128x!tt.ptr> loc(#loc138) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc138) + } else { + %SPARSE_Q_MULTIPLE = arith.constant 2 : i32 loc(#loc1116) + %SPARSE_KV_MULTIPLE = arith.constant 1 : i32 loc(#loc1117) + %pid_mask = arith.divsi %pid, %SPARSE_KV_MULTIPLE : i32 loc(#loc627) + %stride_q_idx_h = arith.muli %ks6, %ks7 : i32 loc(#loc628) + %dv = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc629) + %dk = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc630) + %start_n1 = arith.constant 128 : i32 loc(#loc631) + %start_n1_44 = arith.constant 128 : i32 loc(#loc631) + %start_n1_45 = arith.muli %pid, %start_n1_44 : i32 loc(#loc631) + %offs_n1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc632) + %offs_n1_46 = tt.splat %start_n1_45 : i32 -> tensor<128xi32> loc(#loc633) + %offs_n1_47 = arith.addi %offs_n1_46, %offs_n1 : tensor<128xi32> loc(#loc633) + %k = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%K, %offs_n1_47, %offs_k, %c1024_i32_4, %c1_i32_5, %ks1) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc634) + %v = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%V, %offs_n1_47, %offs_v, %c1024_i32_9, %c1_i32_10, %ks1) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc635) + %c0_i32 = arith.constant 0 : i32 loc(#loc150) + %c4_i32 = arith.constant 4 : i32 loc(#loc150) + %c1_i32_48 = arith.constant 1 : i32 loc(#loc150) + %23 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc150) + %24 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc150) + %25 = arith.bitcast %c1_i32_48 : i32 to i32 loc(#loc150) + %26 = ub.poison : i32 loc(#loc150) + %dk_49:2 = scf.for %off_g = %23 to %24 step %25 iter_args(%dv_89 = %dv, %dk_90 = %dk) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.constant 4 : i32 loc(#loc637) + %off_hq1_91 = arith.constant 4 : i32 loc(#loc637) + %off_hq1_92 = arith.muli %off_hkv, %off_hq1_91 : i32 loc(#loc637) + %off_hq1_93 = arith.addi %off_hq1_92, %off_g : i32 loc(#loc638) + %q_adj1 = arith.muli %c128_i32, %off_hq1_93 : i32 loc(#loc639) + %q_adj1_94 = arith.muli %0, %off_zq : i32 loc(#loc640) + %q_adj1_95 = arith.addi %q_adj1, %q_adj1_94 : i32 loc(#loc641) + %q_adj1_96 = arith.extsi %q_adj1_95 : i32 to i64 loc(#loc642) + %do_adj1 = arith.muli %18, %off_hq1_93 : i32 loc(#loc643) + %do_adj1_97 = arith.muli %10, %off_zq : i32 loc(#loc644) + %do_adj1_98 = arith.addi %do_adj1, %do_adj1_97 : i32 loc(#loc645) + %do_adj1_99 = arith.extsi %do_adj1_98 : i32 to i64 loc(#loc646) + %dq_adj1 = arith.muli %c128_i32_27, %off_hq1_93 : i32 loc(#loc647) + %dq_adj1_100 = arith.muli %19, %off_zq : i32 loc(#loc648) + %dq_adj1_101 = arith.addi %dq_adj1, %dq_adj1_100 : i32 loc(#loc649) + %dq_adj1_102 = arith.extsi %dq_adj1_101 : i32 to i64 loc(#loc650) + %off_chz1 = arith.muli %off_zq, %HQ : i32 loc(#loc651) + %off_chz1_103 = arith.addi %off_chz1, %off_hq1_93 : i32 loc(#loc652) + %off_chz1_104 = arith.muli %off_chz1_103, %ks0 : i32 loc(#loc653) + %off_chz1_105 = arith.extsi %off_chz1_104 : i32 to i64 loc(#loc654) + %Q1 = tt.addptr %arg_Q, %q_adj1_96 : !tt.ptr, i64 loc(#loc655) + %DO1 = tt.addptr %arg_DO, %do_adj1_99 : !tt.ptr, i64 loc(#loc656) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_105 : !tt.ptr, i64 loc(#loc657) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_105 : !tt.ptr, i64 loc(#loc658) + %sparse_idx_hq1 = arith.remsi %off_hq1_93, %SPARSE_HQ : i32 loc(#loc659) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc660) + %sparse_hz_offset_106 = arith.addi %sparse_hz_offset, %sparse_idx_hq1 : i32 loc(#loc661) + %sparse_q_num_blks_offset = arith.muli %sparse_hz_offset_106, %ks5 : i32 loc(#loc662) + %sparse_q_num_blks_offset_107 = arith.addi %sparse_q_num_blks_offset, %pid_mask : i32 loc(#loc663) + %sparse_q_idx_offset = arith.muli %sparse_hz_offset_106, %stride_q_idx_h : i32 loc(#loc664) + %sparse_q_idx_offset_108 = arith.muli %pid_mask, %ks6 : i32 loc(#loc665) + %sparse_q_idx_offset_109 = arith.addi %sparse_q_idx_offset, %sparse_q_idx_offset_108 : i32 loc(#loc666) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset_109 : !tt.ptr, i32 loc(#loc667) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc668) + %q_start_110 = arith.constant 128 : i32 loc(#loc669) + %q_start_111 = arith.constant 128 : i32 loc(#loc669) + %q_start_112 = arith.muli %q_start, %q_start_111 : i32 loc(#loc669) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %sparse_q_num_blks_offset_107 : !tt.ptr, i32 loc(#loc670) + %sparse_q_num_blocks_113 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc671) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc672) + %offs_m1_114 = tt.splat %q_start_112 : i32 -> tensor<64xi32> loc(#loc673) + %offs_m1_115 = arith.addi %offs_m1_114, %offs_m1 : tensor<64xi32> loc(#loc673) + %45:2 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(43,)cconstexpr_bf16__(44,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %Q1, %DO1, %DELTA1, %LSE1, %dk_90, %dv_89, %k, %v, %off_zq, %off_hq1_93, %offs_n1_47, %offs_m1_115, %c4096_i32_1, %c1_i32, %c128_i32_23, %c1_i32_24, %q_indices, %sparse_q_num_blocks_113) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc188) + %q_indices_116 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset_109 : !tt.ptr, i32 loc(#loc674) + %q_start_117 = tt.load %q_indices_116 : !tt.ptr loc(#loc675) + %q_start_118 = arith.constant 128 : i32 loc(#loc676) + %q_start_119 = arith.constant 128 : i32 loc(#loc676) + %q_start_120 = arith.muli %q_start_117, %q_start_119 : i32 loc(#loc676) + %sparse_q_num_blocks_121 = tt.addptr %arg_FULL_Q_NUM_BLKS, %sparse_q_num_blks_offset_107 : !tt.ptr, i32 loc(#loc677) + %sparse_q_num_blocks_122 = tt.load %sparse_q_num_blocks_121 : !tt.ptr loc(#loc678) + %offs_m1_123 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc679) + %offs_m1_124 = tt.splat %q_start_120 : i32 -> tensor<64xi32> loc(#loc680) + %offs_m1_125 = arith.addi %offs_m1_124, %offs_m1_123 : tensor<64xi32> loc(#loc680) + %46:2 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(43,)cconstexpr_bf16__(44,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %Q1, %DO1, %DELTA1, %LSE1, %45#0, %45#1, %k, %v, %off_zq, %off_hq1_93, %offs_n1_47, %offs_m1_125, %c4096_i32_1, %c1_i32, %c128_i32_23, %c1_i32_24, %q_indices_116, %sparse_q_num_blocks_122) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc196) + scf.yield %46#1, %46#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc197) + } loc(#loc1118) + %dv_ptrs = tt.expand_dims %offs_n1_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc681) + %dv_ptrs_50 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc682) + %dv_ptrs_51 = arith.muli %dv_ptrs, %dv_ptrs_50 : tensor<128x1xi32> loc(#loc682) + %dv_ptrs_52 = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc683) + %dv_ptrs_53 = tt.addptr %dv_ptrs_52, %dv_ptrs_51 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc683) + %dv_ptrs_54 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc684) + %dv_ptrs_55 = arith.constant dense<1> : tensor<1x128xi32> loc(#loc685) + %dv_ptrs_56 = arith.muli %dv_ptrs_54, %dv_ptrs_55 : tensor<1x128xi32> loc(#loc685) + %dv_ptrs_57 = tt.broadcast %dv_ptrs_53 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc686) + %dv_ptrs_58 = tt.broadcast %dv_ptrs_56 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc686) + %dv_ptrs_59 = tt.addptr %dv_ptrs_57, %dv_ptrs_58 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc686) + %index_n = tt.expand_dims %offs_n1_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc687) + %index_k = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc688) + %index_v = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc689) + %27 = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc207) + %28 = arith.cmpi slt, %index_n, %27 : tensor<128x1xi32> loc(#loc207) + %c128_i32_60 = arith.constant 128 : i32 loc(#loc208) + %cst = arith.constant dense<128> : tensor<1x128xi32> loc(#loc208) + %29 = arith.cmpi slt, %index_v, %cst : tensor<1x128xi32> loc(#loc208) + %30 = tt.broadcast %28 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc209) + %31 = tt.broadcast %29 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc209) + %32 = arith.andi %30, %31 : tensor<128x128xi1> loc(#loc209) + %33 = arith.truncf %dk_49#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc210) + tt.store %dv_ptrs_59, %33, %32 : tensor<128x128x!tt.ptr> loc(#loc210) + %dk_61 = arith.constant 0.0883883461 : f32 loc(#loc690) + %dk_62 = arith.constant 0.0883883461 : f32 loc(#loc690) + %dk_63 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc690) + %dk_64 = arith.mulf %dk_49#1, %dk_63 : tensor<128x128xf32> loc(#loc690) + %mask = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc691) + %mask_65 = arith.cmpi slt, %index_n, %mask : tensor<128x1xi32> loc(#loc691) + %xindex = arith.constant 128 : i32 loc(#loc692) + %xindex_66 = arith.constant 128 : i32 loc(#loc692) + %xindex_67 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc692) + %xindex_68 = arith.muli %xindex_67, %index_n : tensor<128x1xi32> loc(#loc692) + %xindex_69 = tt.broadcast %index_k : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc693) + %xindex_70 = tt.broadcast %xindex_68 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc693) + %xindex_71 = arith.addi %xindex_69, %xindex_70 : tensor<128x128xi32> loc(#loc693) + %xindex_72 = arith.constant 128 : i32 loc(#loc694) + %xindex_73 = arith.constant 128 : i32 loc(#loc694) + %xindex_74 = arith.muli %xindex_73, %off_hkv : i32 loc(#loc694) + %xindex_75 = arith.muli %xindex_74, %ks1 : i32 loc(#loc695) + %xindex_76 = tt.splat %xindex_75 : i32 -> tensor<128x128xi32> loc(#loc696) + %xindex_77 = arith.addi %xindex_71, %xindex_76 : tensor<128x128xi32> loc(#loc696) + %xindex_78 = arith.constant 1024 : i32 loc(#loc697) + %xindex_79 = arith.constant 1024 : i32 loc(#loc697) + %xindex_80 = arith.muli %xindex_79, %off_zq : i32 loc(#loc697) + %xindex_81 = arith.muli %xindex_80, %ks1 : i32 loc(#loc698) + %xindex_82 = tt.splat %xindex_81 : i32 -> tensor<128x128xi32> loc(#loc699) + %xindex_83 = arith.addi %xindex_77, %xindex_82 : tensor<128x128xi32> loc(#loc699) + %c128_i32_84 = arith.constant 128 : i32 loc(#loc221) + %c128_i32_85 = arith.constant 128 : i32 loc(#loc221) + %34 = arith.muli %c128_i32_85, %off_hkv : i32 loc(#loc221) + %35 = tt.splat %34 : i32 -> tensor<1x128xi32> loc(#loc222) + %36 = arith.addi %index_k, %35 : tensor<1x128xi32> loc(#loc222) + %c1024_i32_86 = arith.constant 1024 : i32 loc(#loc223) + %c1024_i32_87 = arith.constant 1024 : i32 loc(#loc223) + %cst_88 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc223) + %37 = arith.muli %cst_88, %index_n : tensor<128x1xi32> loc(#loc223) + %38 = tt.broadcast %36 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc224) + %39 = tt.broadcast %37 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc224) + %40 = arith.addi %38, %39 : tensor<128x128xi32> loc(#loc224) + %41 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc225) + %42 = tt.addptr %41, %40 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc225) + %43 = tt.broadcast %mask_65 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc226) + %44 = arith.truncf %dk_64 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc226) + tt.store %42, %44, %43 : tensor<128x128x!tt.ptr> loc(#loc226) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc226) + } loc(#loc56) + tt.return loc(#loc227) + } loc(#loc) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%x: i32 loc("x"(#loc228))) -> i32 attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 loc(#loc229) + %c128_i32_0 = arith.constant 128 : i32 loc(#loc229) + %0 = arith.addi %x, %c128_i32_0 : i32 loc(#loc229) + %c1_i32 = arith.constant 1 : i32 loc(#loc230) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc230) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc230) + %c128_i32_2 = arith.constant 128 : i32 loc(#loc231) + %c128_i32_3 = arith.constant 128 : i32 loc(#loc231) + %2 = arith.divsi %1, %c128_i32_3 : i32 loc(#loc231) + tt.return %2 : i32 loc(#loc232) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc233) + tt.return %3 : i32 loc(#loc233) + } loc(#loc228) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() -> tensor<128x128xf32> attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 loc(#loc235) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc235) + tt.return %cst_0 : tensor<128x128xf32> loc(#loc236) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc237) + tt.return %0 : tensor<128x128xf32> loc(#loc237) + } loc(#loc234) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: !tt.ptr loc("ptr"(#loc238)), %offs_m: tensor<128xi32> loc("offs_m"(#loc238)), %offs_n: tensor<128xi32> loc("offs_n"(#loc238)), %stride_m: i32 loc("stride_m"(#loc238)), %stride_n: i32 loc("stride_n"(#loc238)), %M_LEN: i32 loc("M_LEN"(#loc238))) -> tensor<128x128xbf16> attributes {noinline = false} { + %ptr_0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc707) + %ptr_1 = tt.splat %stride_m : i32 -> tensor<128x1xi32> loc(#loc708) + %ptr_2 = arith.muli %ptr_0, %ptr_1 : tensor<128x1xi32> loc(#loc708) + %ptr_3 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc709) + %ptr_4 = tt.addptr %ptr_3, %ptr_2 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc709) + %ptr_5 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc710) + %ptr_6 = tt.splat %stride_n : i32 -> tensor<1x128xi32> loc(#loc711) + %ptr_7 = arith.muli %ptr_5, %ptr_6 : tensor<1x128xi32> loc(#loc711) + %ptr_8 = tt.broadcast %ptr_4 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc712) + %ptr_9 = tt.broadcast %ptr_7 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc712) + %ptr_10 = tt.addptr %ptr_8, %ptr_9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc712) + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc245) + %1 = tt.splat %M_LEN : i32 -> tensor<128x1xi32> loc(#loc246) + %2 = arith.cmpi slt, %0, %1 : tensor<128x1xi32> loc(#loc246) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc247) + %3 = tt.broadcast %2 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc247) + %cst_11 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc247) + %4 = arith.truncf %cst_11 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc247) + %5 = tt.load %ptr_10, %3, %4 : tensor<128x128x!tt.ptr> loc(#loc247) + tt.return %5 : tensor<128x128xbf16> loc(#loc248) + ^bb1: // no predecessors + %6 = ub.poison : tensor<128x128xbf16> loc(#loc249) + tt.return %6 : tensor<128x128xbf16> loc(#loc249) + } loc(#loc238) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(42,)cconstexpr_bf16__(43,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc250)), %arg_K: !tt.ptr loc("arg_K"(#loc250)), %arg_V: !tt.ptr loc("arg_V"(#loc250)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc250)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc250)), %arg_DO: !tt.ptr loc("arg_DO"(#loc250)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc250)), %arg_DV: !tt.ptr loc("arg_DV"(#loc250)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc250)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc250)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc250)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc250)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc250)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc250)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc250)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc250)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc250)), %ks0: i32 loc("ks0"(#loc250)), %ks1: i32 loc("ks1"(#loc250)), %ks2: i32 loc("ks2"(#loc250)), %ks3: i32 loc("ks3"(#loc250)), %ks4: i32 loc("ks4"(#loc250)), %ks5: i32 loc("ks5"(#loc250)), %ks6: i32 loc("ks6"(#loc250)), %ks7: i32 loc("ks7"(#loc250)), %K: !tt.ptr loc("K"(#loc250)), %V: !tt.ptr loc("V"(#loc250)), %dq: tensor<128x128xf32> loc("dq"(#loc250)), %q: tensor<128x128xbf16> loc("q"(#loc250)), %do: tensor<128x128xbf16> loc("do"(#loc250)), %Di: tensor<128xf32> loc("Di"(#loc250)), %lse: tensor<128x1xf32> loc("lse"(#loc250)), %off_z: i32 loc("off_z"(#loc250)), %off_hq: i32 loc("off_hq"(#loc250)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc250)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc250)), %stride_kn: i32 loc("stride_kn"(#loc250)), %stride_kd: i32 loc("stride_kd"(#loc250)), %stride_vn: i32 loc("stride_vn"(#loc250)), %stride_vd: i32 loc("stride_vd"(#loc250)), %kv_indices: !tt.ptr loc("kv_indices"(#loc250)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc250))) -> tensor<128x128xf32> attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc755) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc756) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc757) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc758) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc758) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc759) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc759) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc760) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc761) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc761) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc762) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc762) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc762) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc763) + %vT_ptrs_10 = tt.splat %stride_vn : i32 -> tensor<1x64xi32> loc(#loc764) + %vT_ptrs_11 = arith.muli %vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc764) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc765) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc765) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc766) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc767) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc767) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc768) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc768) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc768) + %hi = arith.constant 2 : i32 loc(#loc769) + %hi_20 = arith.constant 2 : i32 loc(#loc769) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc769) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc770) + %hi_23 = arith.constant 1 : i32 loc(#loc771) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc771) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc772) + %c0_i32 = arith.constant 0 : i32 loc(#loc269) + %c1_i32 = arith.constant 1 : i32 loc(#loc269) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc269) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc269) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc269) + %3 = ub.poison : i32 loc(#loc269) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(46,)cconstexpr_bf16__(47,)cconstexpr_1_d_44269504__(48,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %ks0, %ks1, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc774) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc775) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc776) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc777) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc777) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : i32 loc(#loc778) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc779) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc779) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc780) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc780) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc277) + } loc(#loc1123) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc278) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc279) + tt.return %4 : tensor<128x128xf32> loc(#loc279) + } loc(#loc250) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%x: i32 loc("x"(#loc228))) -> i32 attributes {noinline = false} { + %c64_i32 = arith.constant 64 : i32 loc(#loc229) + %c64_i32_0 = arith.constant 64 : i32 loc(#loc229) + %0 = arith.addi %x, %c64_i32_0 : i32 loc(#loc229) + %c1_i32 = arith.constant 1 : i32 loc(#loc230) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc230) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc230) + %c64_i32_2 = arith.constant 64 : i32 loc(#loc231) + %c64_i32_3 = arith.constant 64 : i32 loc(#loc231) + %2 = arith.divsi %1, %c64_i32_3 : i32 loc(#loc231) + tt.return %2 : i32 loc(#loc232) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc233) + tt.return %3 : i32 loc(#loc233) + } loc(#loc228) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(46,)cconstexpr_bf16__(47,)cconstexpr_1_d_44269504__(48,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc280)), %arg_K: !tt.ptr loc("arg_K"(#loc280)), %arg_V: !tt.ptr loc("arg_V"(#loc280)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc280)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc280)), %arg_DO: !tt.ptr loc("arg_DO"(#loc280)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc280)), %arg_DV: !tt.ptr loc("arg_DV"(#loc280)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc280)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc280)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc280)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc280)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc280)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc280)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc280)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc280)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc280)), %ks0: i32 loc("ks0"(#loc280)), %ks1: i32 loc("ks1"(#loc280)), %ks2: i32 loc("ks2"(#loc280)), %ks3: i32 loc("ks3"(#loc280)), %ks4: i32 loc("ks4"(#loc280)), %ks5: i32 loc("ks5"(#loc280)), %ks6: i32 loc("ks6"(#loc280)), %ks7: i32 loc("ks7"(#loc280)), %dq: tensor<128x128xf32> loc("dq"(#loc280)), %q: tensor<128x128xbf16> loc("q"(#loc280)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc280)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc280)), %do: tensor<128x128xbf16> loc("do"(#loc280)), %Di: tensor<128xf32> loc("Di"(#loc280)), %lse: tensor<128x1xf32> loc("lse"(#loc280)), %Q_LEN: i32 loc("Q_LEN"(#loc280)), %KV_LEN: i32 loc("KV_LEN"(#loc280)), %off_z: i32 loc("off_z"(#loc280)), %off_hq: i32 loc("off_hq"(#loc280)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc280)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc280)), %offs_k: tensor<128xi32> loc("offs_k"(#loc280)), %offs_v: tensor<128xi32> loc("offs_v"(#loc280)), %stride_kn: i32 loc("stride_kn"(#loc280)), %stride_kd: i32 loc("stride_kd"(#loc280)), %stride_vn: i32 loc("stride_vn"(#loc280)), %stride_vd: i32 loc("stride_vd"(#loc280)), %kv_indices: !tt.ptr loc("kv_indices"(#loc280)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc280))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc827) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc828) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc828) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc828) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc829) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc829) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc829) + %qk_5 = arith.mulf %qk_1, %qk_4 : tensor<128x64xf32> loc(#loc829) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc830) + %n_6 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S1_64S_i32__(%n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc831) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc832) + %m_7 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S128_1S_i32__(%m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc833) + %post_mod_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc834) + %post_mod_scores_8 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc835) + %post_mod_scores_9 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_8 : tensor<1x64xi32> loc(#loc835) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc836) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc836) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc836) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc836) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_5, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc836) + %tmp2 = arith.constant 0 : i32 loc(#loc837) + %tmp2_15 = arith.constant dense<0> : tensor<1xi32> loc(#loc837) + %tmp3 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc838) + %tmp3_16 = tt.broadcast %tmp3 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc838) + %tmp3_17 = arith.cmpi slt, %m_7, %tmp3_16 : tensor<128x1xi32> loc(#loc838) + %tmp5 = tt.broadcast %n_6 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc839) + %tmp5_18 = tt.broadcast %m_7 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc839) + %tmp5_19 = arith.cmpi sle, %tmp5, %tmp5_18 : tensor<128x64xi32> loc(#loc839) + %tmp6 = tt.broadcast %tmp3_17 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc840) + %tmp6_20 = arith.andi %tmp6, %tmp5_19 : tensor<128x64xi1> loc(#loc840) + %tmp7 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc841) + %tmp7_21 = tt.broadcast %tmp7 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc841) + %tmp7_22 = arith.cmpi sge, %m_7, %tmp7_21 : tensor<128x1xi32> loc(#loc841) + %tmp8 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc842) + %tmp8_23 = tt.broadcast %tmp8 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc842) + %tmp8_24 = arith.cmpi slt, %n_6, %tmp8_23 : tensor<1x64xi32> loc(#loc842) + %tmp9 = tt.broadcast %tmp7_22 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc843) + %tmp9_25 = tt.broadcast %tmp8_24 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc843) + %tmp9_26 = arith.andi %tmp9, %tmp9_25 : tensor<128x64xi1> loc(#loc843) + %tmp10 = arith.constant 0 : i32 loc(#loc844) + %tmp10_27 = arith.extui %tmp8_24 : tensor<1x64xi1> to tensor<1x64xi32> loc(#loc844) + %tmp10_28 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc844) + %tmp10_29 = arith.cmpi eq, %tmp10_27, %tmp10_28 : tensor<1x64xi32> loc(#loc844) + %tmp11 = tt.broadcast %tmp7_22 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc845) + %tmp11_30 = tt.broadcast %tmp10_29 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc845) + %tmp11_31 = arith.andi %tmp11, %tmp11_30 : tensor<128x64xi1> loc(#loc845) + %tmp12 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc846) + %tmp12_32 = tt.broadcast %tmp12 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc846) + %tmp12_33 = arith.subi %m_7, %tmp12_32 : tensor<128x1xi32> loc(#loc846) + %tmp13 = arith.constant 16 : i32 loc(#loc847) + %tmp13_34 = arith.constant dense<16> : tensor<1xi32> loc(#loc847) + %tmp14 = arith.constant 0 : i32 loc(#loc848) + %tmp14_35 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc848) + %tmp14_36 = arith.cmpi slt, %tmp12_33, %tmp14_35 : tensor<128x1xi32> loc(#loc848) + %tmp14_37 = arith.constant 0 : i32 loc(#loc849) + %tmp14_38 = arith.constant dense<0> : tensor<1xi32> loc(#loc849) + %tmp14_39 = arith.cmpi slt, %tmp13_34, %tmp14_38 : tensor<1xi32> loc(#loc849) + %tmp14_40 = tt.expand_dims %tmp14_39 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc850) + %tmp14_41 = tt.broadcast %tmp14_40 : tensor<1x1xi1> -> tensor<128x1xi1> loc(#loc850) + %tmp14_42 = arith.cmpi ne, %tmp14_36, %tmp14_41 : tensor<128x1xi1> loc(#loc850) + %tmp14_43 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc851) + %tmp14_44 = tt.broadcast %tmp14_43 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc851) + %tmp14_45 = arith.remsi %tmp12_33, %tmp14_44 : tensor<128x1xi32> loc(#loc851) + %tmp14_46 = arith.constant 0 : i32 loc(#loc852) + %tmp14_47 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc852) + %tmp14_48 = arith.cmpi ne, %tmp14_45, %tmp14_47 : tensor<128x1xi32> loc(#loc852) + %tmp14_49 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc853) + %tmp14_50 = tt.broadcast %tmp14_49 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc853) + %tmp14_51 = arith.divsi %tmp12_33, %tmp14_50 : tensor<128x1xi32> loc(#loc853) + %tmp14_52 = arith.constant 1 : i32 loc(#loc854) + %tmp14_53 = arith.constant 1 : i32 loc(#loc854) + %tmp14_54 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc854) + %tmp14_55 = arith.subi %tmp14_51, %tmp14_54 : tensor<128x1xi32> loc(#loc854) + %tmp14_56 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc855) + %tmp14_57 = tt.broadcast %tmp14_56 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc855) + %tmp14_58 = arith.divsi %tmp12_33, %tmp14_57 : tensor<128x1xi32> loc(#loc855) + %tmp14_59 = arith.select %tmp14_48, %tmp14_55, %tmp14_58 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc856) + %tmp14_60 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc857) + %tmp14_61 = tt.broadcast %tmp14_60 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc857) + %tmp14_62 = arith.divsi %tmp12_33, %tmp14_61 : tensor<128x1xi32> loc(#loc857) + %tmp14_63 = arith.select %tmp14_42, %tmp14_59, %tmp14_62 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc858) + %tmp15 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc859) + %tmp15_64 = tt.broadcast %tmp15 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc859) + %tmp15_65 = arith.subi %n_6, %tmp15_64 : tensor<1x64xi32> loc(#loc859) + %tmp16 = arith.constant 0 : i32 loc(#loc860) + %tmp16_66 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc860) + %tmp16_67 = arith.cmpi slt, %tmp15_65, %tmp16_66 : tensor<1x64xi32> loc(#loc860) + %tmp16_68 = arith.constant 0 : i32 loc(#loc861) + %tmp16_69 = arith.constant dense<0> : tensor<1xi32> loc(#loc861) + %tmp16_70 = arith.cmpi slt, %tmp13_34, %tmp16_69 : tensor<1xi32> loc(#loc861) + %tmp16_71 = tt.expand_dims %tmp16_70 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc862) + %tmp16_72 = tt.broadcast %tmp16_71 : tensor<1x1xi1> -> tensor<1x64xi1> loc(#loc862) + %tmp16_73 = arith.cmpi ne, %tmp16_67, %tmp16_72 : tensor<1x64xi1> loc(#loc862) + %tmp16_74 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc863) + %tmp16_75 = tt.broadcast %tmp16_74 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc863) + %tmp16_76 = arith.remsi %tmp15_65, %tmp16_75 : tensor<1x64xi32> loc(#loc863) + %tmp16_77 = arith.constant 0 : i32 loc(#loc864) + %tmp16_78 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc864) + %tmp16_79 = arith.cmpi ne, %tmp16_76, %tmp16_78 : tensor<1x64xi32> loc(#loc864) + %tmp16_80 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc865) + %tmp16_81 = tt.broadcast %tmp16_80 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc865) + %tmp16_82 = arith.divsi %tmp15_65, %tmp16_81 : tensor<1x64xi32> loc(#loc865) + %tmp16_83 = arith.constant 1 : i32 loc(#loc866) + %tmp16_84 = arith.constant 1 : i32 loc(#loc866) + %tmp16_85 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc866) + %tmp16_86 = arith.subi %tmp16_82, %tmp16_85 : tensor<1x64xi32> loc(#loc866) + %tmp16_87 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc867) + %tmp16_88 = tt.broadcast %tmp16_87 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc867) + %tmp16_89 = arith.divsi %tmp15_65, %tmp16_88 : tensor<1x64xi32> loc(#loc867) + %tmp16_90 = arith.select %tmp16_79, %tmp16_86, %tmp16_89 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc868) + %tmp16_91 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc869) + %tmp16_92 = tt.broadcast %tmp16_91 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc869) + %tmp16_93 = arith.divsi %tmp15_65, %tmp16_92 : tensor<1x64xi32> loc(#loc869) + %tmp16_94 = arith.select %tmp16_73, %tmp16_90, %tmp16_93 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc870) + %tmp17 = tt.broadcast %tmp14_63 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc871) + %tmp17_95 = tt.broadcast %tmp16_94 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc871) + %tmp17_96 = arith.cmpi eq, %tmp17, %tmp17_95 : tensor<128x64xi32> loc(#loc871) + %tmp18 = arith.andi %tmp11_31, %tmp17_96 : tensor<128x64xi1> loc(#loc872) + %tmp19 = arith.ori %tmp9_26, %tmp18 : tensor<128x64xi1> loc(#loc873) + %tmp20 = arith.ori %tmp6_20, %tmp19 : tensor<128x64xi1> loc(#loc874) + %post_mod_scores_97 = arith.constant 0xFF800000 : f32 loc(#loc875) + %post_mod_scores_98 = arith.constant 0xFF800000 : f32 loc(#loc875) + %post_mod_scores_99 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc875) + %post_mod_scores_100 = arith.select %tmp20, %post_mod_scores_14, %post_mod_scores_99 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc875) + %post_mod_scores_101 = arith.constant 1.44269502 : f32 loc(#loc876) + %post_mod_scores_102 = arith.constant 1.44269502 : f32 loc(#loc876) + %post_mod_scores_103 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc876) + %post_mod_scores_104 = arith.mulf %post_mod_scores_100, %post_mod_scores_103 : tensor<128x64xf32> loc(#loc876) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc877) + %p_105 = arith.subf %post_mod_scores_104, %p : tensor<128x64xf32> loc(#loc877) + %p_106 = math.exp2 %p_105 : tensor<128x64xf32> loc(#loc878) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc879) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc880) + %dp_107 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc880) + %dp_108 = tt.dot %do, %vT, %dp_107, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc880) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc881) + %ds_109 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc882) + %ds_110 = arith.subf %dp_108, %ds_109 : tensor<128x64xf32> loc(#loc882) + %ds_111 = arith.mulf %p_106, %ds_110 : tensor<128x64xf32> loc(#loc883) + %grad_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc884) + %grad_scores_112 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc885) + %grad_scores_113 = arith.cmpi slt, %grad_scores, %grad_scores_112 : tensor<1x64xi32> loc(#loc885) + %grad_scores_114 = arith.constant 0.000000e+00 : f32 loc(#loc886) + %grad_scores_115 = arith.constant 0.000000e+00 : f32 loc(#loc886) + %grad_scores_116 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc886) + %grad_scores_117 = tt.broadcast %grad_scores_113 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc886) + %grad_scores_118 = arith.select %grad_scores_117, %ds_111, %grad_scores_116 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc886) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc887) + %scatter_mask_119 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc888) + %scatter_mask_120 = arith.cmpi slt, %scatter_mask, %scatter_mask_119 : tensor<128x1xi32> loc(#loc888) + %scatter_mask_121 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc889) + %scatter_mask_122 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc890) + %scatter_mask_123 = arith.cmpi slt, %scatter_mask_121, %scatter_mask_122 : tensor<1x64xi32> loc(#loc890) + %scatter_mask_124 = tt.broadcast %scatter_mask_120 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc891) + %scatter_mask_125 = tt.broadcast %scatter_mask_123 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc891) + %scatter_mask_126 = arith.andi %scatter_mask_124, %scatter_mask_125 : tensor<128x64xi1> loc(#loc891) + %ds_127 = arith.constant 0.000000e+00 : f32 loc(#loc892) + %ds_128 = arith.constant 0.000000e+00 : f32 loc(#loc892) + %ds_129 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc892) + %ds_130 = arith.select %tmp20, %grad_scores_118, %ds_129 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc892) + %ds_131 = arith.truncf %ds_130 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc893) + %dq_132 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc894) + %dq_133 = arith.constant 0.000000e+00 : f32 loc(#loc895) + %dq_134 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc895) + %dq_135 = tt.dot %ds_131, %dq_132, %dq_134, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc895) + %dq_136 = arith.addf %dq, %dq_135 : tensor<128x128xf32> loc(#loc896) + tt.return %dq_136 : tensor<128x128xf32> loc(#loc351) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc352) + tt.return %0 : tensor<128x128xf32> loc(#loc352) + } loc(#loc280) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%ptr: tensor<128x64x!tt.ptr> loc("ptr"(#loc238)), %offs_m: tensor<128xi32> loc("offs_m"(#loc238)), %offs_n: tensor<64xi32> loc("offs_n"(#loc238)), %N_LEN: i32 loc("N_LEN"(#loc238))) -> tensor<128x64xbf16> attributes {noinline = false} { + %0 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc353) + %1 = tt.splat %N_LEN : i32 -> tensor<1x64xi32> loc(#loc354) + %2 = arith.cmpi slt, %0, %1 : tensor<1x64xi32> loc(#loc354) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc355) + %3 = tt.broadcast %2 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc355) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc355) + %4 = arith.truncf %cst_0 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc355) + %5 = tt.load %ptr, %3, %4 : tensor<128x64x!tt.ptr> loc(#loc355) + tt.return %5 : tensor<128x64xbf16> loc(#loc356) + ^bb1: // no predecessors + %6 = ub.poison : tensor<128x64xbf16> loc(#loc249) + tt.return %6 : tensor<128x64xbf16> loc(#loc249) + } loc(#loc238) + tt.func private @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S1_64S_i32__(%indices: tensor<1x64xi32> loc("indices"(#loc357)), %max_len: i32 loc("max_len"(#loc357))) -> tensor<1x64xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<1x64xi32> loc(#loc358) + %1 = arith.remsi %indices, %0 : tensor<1x64xi32> loc(#loc358) + tt.return %1 : tensor<1x64xi32> loc(#loc359) + ^bb1: // no predecessors + %2 = ub.poison : tensor<1x64xi32> loc(#loc360) + tt.return %2 : tensor<1x64xi32> loc(#loc360) + } loc(#loc357) + tt.func private @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S128_1S_i32__(%indices: tensor<128x1xi32> loc("indices"(#loc357)), %max_len: i32 loc("max_len"(#loc357))) -> tensor<128x1xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<128x1xi32> loc(#loc358) + %1 = arith.remsi %indices, %0 : tensor<128x1xi32> loc(#loc358) + tt.return %1 : tensor<128x1xi32> loc(#loc359) + ^bb1: // no predecessors + %2 = ub.poison : tensor<128x1xi32> loc(#loc360) + tt.return %2 : tensor<128x1xi32> loc(#loc360) + } loc(#loc357) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%loop_iter: i32 loc("loop_iter"(#loc361)), %col_indices: !tt.ptr loc("col_indices"(#loc361)), %total_blocks: i32 loc("total_blocks"(#loc361))) -> i32 attributes {noinline = false} { + %cur_block_idx = arith.constant 2 : i32 loc(#loc903) + %cur_block_idx_0 = arith.constant 2 : i32 loc(#loc903) + %cur_block_idx_1 = arith.divsi %loop_iter, %cur_block_idx_0 : i32 loc(#loc903) + %cur_block = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc904) + %cur_block_2 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc905) + %next_block = arith.constant 1 : i32 loc(#loc906) + %next_block_3 = arith.constant 1 : i32 loc(#loc906) + %next_block_4 = arith.addi %cur_block_idx_1, %next_block_3 : i32 loc(#loc906) + %next_block_5 = arith.cmpi slt, %next_block_4, %total_blocks : i32 loc(#loc907) + %next_block_6 = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc908) + %next_block_7 = arith.constant 1 : i32 loc(#loc909) + %next_block_8 = tt.addptr %next_block_6, %next_block_7 : !tt.ptr, i32 loc(#loc909) + %next_block_9 = tt.load %next_block_8, %next_block_5 evictionPolicy = evict_last : !tt.ptr loc(#loc910) + %needs_jump = arith.constant 1 : i32 loc(#loc911) + %needs_jump_10 = arith.constant 1 : i32 loc(#loc911) + %needs_jump_11 = arith.addi %loop_iter, %needs_jump_10 : i32 loc(#loc911) + %needs_jump_12 = arith.constant 2 : i32 loc(#loc912) + %needs_jump_13 = arith.constant 2 : i32 loc(#loc912) + %needs_jump_14 = arith.remsi %needs_jump_11, %needs_jump_13 : i32 loc(#loc912) + %needs_jump_15 = arith.constant 0 : i32 loc(#loc913) + %needs_jump_16 = arith.cmpi eq, %needs_jump_14, %needs_jump_15 : i32 loc(#loc913) + %jump_to_block = arith.subi %next_block_9, %cur_block_2 : i32 loc(#loc914) + %jump_to_block_17 = arith.constant 128 : i32 loc(#loc915) + %jump_to_block_18 = arith.constant 128 : i32 loc(#loc915) + %jump_to_block_19 = arith.muli %jump_to_block, %jump_to_block_18 : i32 loc(#loc915) + %jump_to_block_20 = arith.constant 64 : i32 loc(#loc916) + %jump_to_block_21 = arith.constant 64 : i32 loc(#loc916) + %jump_to_block_22 = arith.subi %jump_to_block_19, %jump_to_block_21 : i32 loc(#loc916) + %offset = arith.extui %needs_jump_16 : i1 to i32 loc(#loc917) + %offset_23 = arith.muli %jump_to_block_22, %offset : i32 loc(#loc917) + %offset_24 = arith.constant 1 : i32 loc(#loc918) + %offset_25 = arith.constant 1 : i32 loc(#loc918) + %offset_26 = arith.extui %needs_jump_16 : i1 to i32 loc(#loc918) + %offset_27 = arith.subi %offset_25, %offset_26 : i32 loc(#loc918) + %offset_28 = arith.constant 64 : i32 loc(#loc919) + %offset_29 = arith.constant 64 : i32 loc(#loc919) + %offset_30 = arith.muli %offset_27, %offset_29 : i32 loc(#loc919) + %offset_31 = arith.addi %offset_23, %offset_30 : i32 loc(#loc920) + tt.return %offset_31 : i32 loc(#loc380) + ^bb1: // no predecessors + %0 = ub.poison : i32 loc(#loc381) + tt.return %0 : i32 loc(#loc381) + } loc(#loc361) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(42,)cconstexpr_bf16__(43,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc250)), %arg_K: !tt.ptr loc("arg_K"(#loc250)), %arg_V: !tt.ptr loc("arg_V"(#loc250)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc250)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc250)), %arg_DO: !tt.ptr loc("arg_DO"(#loc250)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc250)), %arg_DV: !tt.ptr loc("arg_DV"(#loc250)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc250)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc250)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc250)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc250)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc250)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc250)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc250)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc250)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc250)), %ks0: i32 loc("ks0"(#loc250)), %ks1: i32 loc("ks1"(#loc250)), %ks2: i32 loc("ks2"(#loc250)), %ks3: i32 loc("ks3"(#loc250)), %ks4: i32 loc("ks4"(#loc250)), %ks5: i32 loc("ks5"(#loc250)), %ks6: i32 loc("ks6"(#loc250)), %ks7: i32 loc("ks7"(#loc250)), %K: !tt.ptr loc("K"(#loc250)), %V: !tt.ptr loc("V"(#loc250)), %dq: tensor<128x128xf32> loc("dq"(#loc250)), %q: tensor<128x128xbf16> loc("q"(#loc250)), %do: tensor<128x128xbf16> loc("do"(#loc250)), %Di: tensor<128xf32> loc("Di"(#loc250)), %lse: tensor<128x1xf32> loc("lse"(#loc250)), %off_z: i32 loc("off_z"(#loc250)), %off_hq: i32 loc("off_hq"(#loc250)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc250)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc250)), %stride_kn: i32 loc("stride_kn"(#loc250)), %stride_kd: i32 loc("stride_kd"(#loc250)), %stride_vn: i32 loc("stride_vn"(#loc250)), %stride_vd: i32 loc("stride_vd"(#loc250)), %kv_indices: !tt.ptr loc("kv_indices"(#loc250)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc250))) -> tensor<128x128xf32> attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc755) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc756) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc757) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc758) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc758) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc759) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc759) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc760) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc761) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc761) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc762) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc762) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc762) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc763) + %vT_ptrs_10 = tt.splat %stride_vn : i32 -> tensor<1x64xi32> loc(#loc764) + %vT_ptrs_11 = arith.muli %vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc764) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc765) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc765) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc766) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc767) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc767) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc768) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc768) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc768) + %hi = arith.constant 2 : i32 loc(#loc769) + %hi_20 = arith.constant 2 : i32 loc(#loc769) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc769) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc770) + %hi_23 = arith.constant 1 : i32 loc(#loc771) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc771) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc772) + %c0_i32 = arith.constant 0 : i32 loc(#loc269) + %c1_i32 = arith.constant 1 : i32 loc(#loc269) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc269) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc269) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc269) + %3 = ub.poison : i32 loc(#loc269) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(46,)cconstexpr_bf16__(47,)cconstexpr_1_d_44269504__(48,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %ks0, %ks1, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc774) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc775) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc776) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc777) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc777) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : i32 loc(#loc778) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc779) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc779) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc780) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc780) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc277) + } loc(#loc1123) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc278) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc279) + tt.return %4 : tensor<128x128xf32> loc(#loc279) + } loc(#loc250) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(46,)cconstexpr_bf16__(47,)cconstexpr_1_d_44269504__(48,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc280)), %arg_K: !tt.ptr loc("arg_K"(#loc280)), %arg_V: !tt.ptr loc("arg_V"(#loc280)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc280)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc280)), %arg_DO: !tt.ptr loc("arg_DO"(#loc280)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc280)), %arg_DV: !tt.ptr loc("arg_DV"(#loc280)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc280)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc280)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc280)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc280)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc280)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc280)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc280)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc280)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc280)), %ks0: i32 loc("ks0"(#loc280)), %ks1: i32 loc("ks1"(#loc280)), %ks2: i32 loc("ks2"(#loc280)), %ks3: i32 loc("ks3"(#loc280)), %ks4: i32 loc("ks4"(#loc280)), %ks5: i32 loc("ks5"(#loc280)), %ks6: i32 loc("ks6"(#loc280)), %ks7: i32 loc("ks7"(#loc280)), %dq: tensor<128x128xf32> loc("dq"(#loc280)), %q: tensor<128x128xbf16> loc("q"(#loc280)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc280)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc280)), %do: tensor<128x128xbf16> loc("do"(#loc280)), %Di: tensor<128xf32> loc("Di"(#loc280)), %lse: tensor<128x1xf32> loc("lse"(#loc280)), %Q_LEN: i32 loc("Q_LEN"(#loc280)), %KV_LEN: i32 loc("KV_LEN"(#loc280)), %off_z: i32 loc("off_z"(#loc280)), %off_hq: i32 loc("off_hq"(#loc280)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc280)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc280)), %offs_k: tensor<128xi32> loc("offs_k"(#loc280)), %offs_v: tensor<128xi32> loc("offs_v"(#loc280)), %stride_kn: i32 loc("stride_kn"(#loc280)), %stride_kd: i32 loc("stride_kd"(#loc280)), %stride_vn: i32 loc("stride_vn"(#loc280)), %stride_vd: i32 loc("stride_vd"(#loc280)), %kv_indices: !tt.ptr loc("kv_indices"(#loc280)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc280))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc827) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc828) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc828) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc828) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc829) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc829) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc829) + %qk_5 = arith.mulf %qk_1, %qk_4 : tensor<128x64xf32> loc(#loc829) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc830) + %n_6 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S1_64S_i32__(%n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc831) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc832) + %m_7 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S128_1S_i32__(%m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc833) + %post_mod_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc834) + %post_mod_scores_8 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc835) + %post_mod_scores_9 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_8 : tensor<1x64xi32> loc(#loc835) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc836) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc836) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc836) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc836) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_5, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc836) + %post_mod_scores_15 = arith.constant 1.44269502 : f32 loc(#loc876) + %post_mod_scores_16 = arith.constant 1.44269502 : f32 loc(#loc876) + %post_mod_scores_17 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc876) + %post_mod_scores_18 = arith.mulf %post_mod_scores_14, %post_mod_scores_17 : tensor<128x64xf32> loc(#loc876) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc877) + %p_19 = arith.subf %post_mod_scores_18, %p : tensor<128x64xf32> loc(#loc877) + %p_20 = math.exp2 %p_19 : tensor<128x64xf32> loc(#loc878) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc879) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc880) + %dp_21 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc880) + %dp_22 = tt.dot %do, %vT, %dp_21, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc880) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc881) + %ds_23 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc882) + %ds_24 = arith.subf %dp_22, %ds_23 : tensor<128x64xf32> loc(#loc882) + %ds_25 = arith.mulf %p_20, %ds_24 : tensor<128x64xf32> loc(#loc883) + %grad_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc884) + %grad_scores_26 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc885) + %grad_scores_27 = arith.cmpi slt, %grad_scores, %grad_scores_26 : tensor<1x64xi32> loc(#loc885) + %grad_scores_28 = arith.constant 0.000000e+00 : f32 loc(#loc886) + %grad_scores_29 = arith.constant 0.000000e+00 : f32 loc(#loc886) + %grad_scores_30 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc886) + %grad_scores_31 = tt.broadcast %grad_scores_27 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc886) + %grad_scores_32 = arith.select %grad_scores_31, %ds_25, %grad_scores_30 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc886) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc887) + %scatter_mask_33 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc888) + %scatter_mask_34 = arith.cmpi slt, %scatter_mask, %scatter_mask_33 : tensor<128x1xi32> loc(#loc888) + %scatter_mask_35 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc889) + %scatter_mask_36 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc890) + %scatter_mask_37 = arith.cmpi slt, %scatter_mask_35, %scatter_mask_36 : tensor<1x64xi32> loc(#loc890) + %scatter_mask_38 = tt.broadcast %scatter_mask_34 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc891) + %scatter_mask_39 = tt.broadcast %scatter_mask_37 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc891) + %scatter_mask_40 = arith.andi %scatter_mask_38, %scatter_mask_39 : tensor<128x64xi1> loc(#loc891) + %ds_41 = arith.truncf %grad_scores_32 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc893) + %dq_42 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc894) + %dq_43 = arith.constant 0.000000e+00 : f32 loc(#loc895) + %dq_44 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc895) + %dq_45 = tt.dot %ds_41, %dq_42, %dq_44, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc895) + %dq_46 = arith.addf %dq, %dq_45 : tensor<128x128xf32> loc(#loc896) + tt.return %dq_46 : tensor<128x128xf32> loc(#loc351) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc352) + tt.return %0 : tensor<128x128xf32> loc(#loc352) + } loc(#loc280) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(43,)cconstexpr_bf16__(44,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc382)), %arg_K: !tt.ptr loc("arg_K"(#loc382)), %arg_V: !tt.ptr loc("arg_V"(#loc382)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc382)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc382)), %arg_DO: !tt.ptr loc("arg_DO"(#loc382)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc382)), %arg_DV: !tt.ptr loc("arg_DV"(#loc382)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc382)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc382)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc382)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc382)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc382)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc382)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc382)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc382)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc382)), %ks0: i32 loc("ks0"(#loc382)), %ks1: i32 loc("ks1"(#loc382)), %ks2: i32 loc("ks2"(#loc382)), %ks3: i32 loc("ks3"(#loc382)), %ks4: i32 loc("ks4"(#loc382)), %ks5: i32 loc("ks5"(#loc382)), %ks6: i32 loc("ks6"(#loc382)), %ks7: i32 loc("ks7"(#loc382)), %Q: !tt.ptr loc("Q"(#loc382)), %DO: !tt.ptr loc("DO"(#loc382)), %DELTA: !tt.ptr loc("DELTA"(#loc382)), %LSE: !tt.ptr loc("LSE"(#loc382)), %dk: tensor<128x128xf32> loc("dk"(#loc382)), %dv: tensor<128x128xf32> loc("dv"(#loc382)), %k: tensor<128x128xbf16> loc("k"(#loc382)), %v: tensor<128x128xbf16> loc("v"(#loc382)), %off_z: i32 loc("off_z"(#loc382)), %off_hq: i32 loc("off_hq"(#loc382)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc382)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc382)), %stride_qm: i32 loc("stride_qm"(#loc382)), %stride_qd: i32 loc("stride_qd"(#loc382)), %stride_dom: i32 loc("stride_dom"(#loc382)), %stride_dod: i32 loc("stride_dod"(#loc382)), %q_indices: !tt.ptr loc("q_indices"(#loc382)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc382))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc964) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc965) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc966) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc967) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc967) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc968) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc968) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc969) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc970) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc970) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc971) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc971) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc971) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc972) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc973) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc973) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc974) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc974) + %do_ptrs_14 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc975) + %do_ptrs_15 = tt.splat %stride_dod : i32 -> tensor<1x128xi32> loc(#loc976) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc976) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc977) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc977) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc977) + %hi = arith.constant 2 : i32 loc(#loc978) + %hi_20 = arith.constant 2 : i32 loc(#loc978) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc978) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks0) : (i32) -> i32 loc(#loc979) + %hi_23 = arith.constant 1 : i32 loc(#loc980) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc980) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc981) + %c0_i32 = arith.constant 0 : i32 loc(#loc401) + %c1_i32 = arith.constant 1 : i32 loc(#loc401) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc401) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc401) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc401) + %3 = ub.poison : i32 loc(#loc401) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(47,)cconstexpr_bf16__(48,)cconstexpr_1_d_44269504__(49,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %ks0, %ks1, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc402) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc983) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc984) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc985) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc985) + %do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc986) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc987) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc987) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc988) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc988) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc409) + } loc(#loc1125) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc410) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc411) + %5 = ub.poison : tensor<128x128xf32> loc(#loc411) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc411) + } loc(#loc382) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(47,)cconstexpr_bf16__(48,)cconstexpr_1_d_44269504__(49,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc412)), %arg_K: !tt.ptr loc("arg_K"(#loc412)), %arg_V: !tt.ptr loc("arg_V"(#loc412)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc412)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc412)), %arg_DO: !tt.ptr loc("arg_DO"(#loc412)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc412)), %arg_DV: !tt.ptr loc("arg_DV"(#loc412)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc412)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc412)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc412)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc412)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc412)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc412)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc412)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc412)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc412)), %ks0: i32 loc("ks0"(#loc412)), %ks1: i32 loc("ks1"(#loc412)), %ks2: i32 loc("ks2"(#loc412)), %ks3: i32 loc("ks3"(#loc412)), %ks4: i32 loc("ks4"(#loc412)), %ks5: i32 loc("ks5"(#loc412)), %ks6: i32 loc("ks6"(#loc412)), %ks7: i32 loc("ks7"(#loc412)), %dk: tensor<128x128xf32> loc("dk"(#loc412)), %dv: tensor<128x128xf32> loc("dv"(#loc412)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc412)), %k: tensor<128x128xbf16> loc("k"(#loc412)), %v: tensor<128x128xbf16> loc("v"(#loc412)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc412)), %DELTA: !tt.ptr loc("DELTA"(#loc412)), %LSE: !tt.ptr loc("LSE"(#loc412)), %Q_LEN: i32 loc("Q_LEN"(#loc412)), %KV_LEN: i32 loc("KV_LEN"(#loc412)), %off_z: i32 loc("off_z"(#loc412)), %off_hq: i32 loc("off_hq"(#loc412)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc412)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc412)), %offs_k: tensor<128xi32> loc("offs_k"(#loc412)), %offs_v: tensor<128xi32> loc("offs_v"(#loc412)), %stride_qm: i32 loc("stride_qm"(#loc412)), %stride_qd: i32 loc("stride_qd"(#loc412)), %stride_dom: i32 loc("stride_dom"(#loc412)), %stride_dod: i32 loc("stride_dod"(#loc412)), %q_indices: !tt.ptr loc("q_indices"(#loc412)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc412))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc1036) + %lse = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1037) + %lse_0 = arith.cmpi slt, %offs_m1, %lse : tensor<64xi32> loc(#loc1037) + %lse_1 = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1038) + %lse_2 = tt.addptr %lse_1, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1038) + %lse_3 = tt.load %lse_2, %lse_0 : tensor<64x!tt.ptr> loc(#loc1039) + %lse_4 = arith.constant 0xFF800000 : f32 loc(#loc1040) + %lse_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1040) + %lse_6 = arith.cmpf oeq, %lse_3, %lse_5 : tensor<64xf32> loc(#loc1040) + %lse_7 = arith.constant 0.000000e+00 : f32 loc(#loc1041) + %lse_8 = arith.constant 0.000000e+00 : f32 loc(#loc1041) + %lse_9 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1041) + %lse_10 = arith.select %lse_6, %lse_9, %lse_3 : tensor<64xi1>, tensor<64xf32> loc(#loc1041) + %qkT = arith.constant 0.000000e+00 : f32 loc(#loc1042) + %qkT_11 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1042) + %qkT_12 = tt.dot %k, %qT, %qkT_11, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1042) + %qkT_13 = arith.constant 0.0883883461 : f32 loc(#loc1043) + %qkT_14 = arith.constant 0.0883883461 : f32 loc(#loc1043) + %qkT_15 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1043) + %qkT_16 = arith.mulf %qkT_12, %qkT_15 : tensor<128x64xf32> loc(#loc1043) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1044) + %m_17 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S1_64S_i32__(%m, %Q_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc1045) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc1046) + %n_18 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S128_1S_i32__(%n, %KV_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc1047) + %post_mod_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1048) + %post_mod_scores_19 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1049) + %post_mod_scores_20 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_19 : tensor<1x64xi32> loc(#loc1049) + %post_mod_scores_21 = arith.constant 0xFF800000 : f32 loc(#loc1050) + %post_mod_scores_22 = arith.constant 0xFF800000 : f32 loc(#loc1050) + %post_mod_scores_23 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1050) + %post_mod_scores_24 = tt.broadcast %post_mod_scores_20 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1050) + %post_mod_scores_25 = arith.select %post_mod_scores_24, %qkT_16, %post_mod_scores_23 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1050) + %tmp24 = arith.constant 0 : i32 loc(#loc1051) + %tmp24_26 = arith.constant dense<0> : tensor<1xi32> loc(#loc1051) + %tmp25 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1052) + %tmp25_27 = tt.broadcast %tmp25 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1052) + %tmp25_28 = arith.cmpi slt, %m_17, %tmp25_27 : tensor<1x64xi32> loc(#loc1052) + %tmp27 = tt.broadcast %n_18 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc1053) + %tmp27_29 = tt.broadcast %m_17 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc1053) + %tmp27_30 = arith.cmpi sle, %tmp27, %tmp27_29 : tensor<128x64xi32> loc(#loc1053) + %tmp28 = tt.broadcast %tmp25_28 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1054) + %tmp28_31 = arith.andi %tmp28, %tmp27_30 : tensor<128x64xi1> loc(#loc1054) + %tmp29 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1055) + %tmp29_32 = tt.broadcast %tmp29 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1055) + %tmp29_33 = arith.cmpi sge, %m_17, %tmp29_32 : tensor<1x64xi32> loc(#loc1055) + %tmp30 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1056) + %tmp30_34 = tt.broadcast %tmp30 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1056) + %tmp30_35 = arith.cmpi slt, %n_18, %tmp30_34 : tensor<128x1xi32> loc(#loc1056) + %tmp31 = tt.broadcast %tmp29_33 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1057) + %tmp31_36 = tt.broadcast %tmp30_35 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc1057) + %tmp31_37 = arith.andi %tmp31, %tmp31_36 : tensor<128x64xi1> loc(#loc1057) + %tmp32 = arith.constant 0 : i32 loc(#loc1058) + %tmp32_38 = arith.extui %tmp30_35 : tensor<128x1xi1> to tensor<128x1xi32> loc(#loc1058) + %tmp32_39 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1058) + %tmp32_40 = arith.cmpi eq, %tmp32_38, %tmp32_39 : tensor<128x1xi32> loc(#loc1058) + %tmp33 = tt.broadcast %tmp29_33 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1059) + %tmp33_41 = tt.broadcast %tmp32_40 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc1059) + %tmp33_42 = arith.andi %tmp33, %tmp33_41 : tensor<128x64xi1> loc(#loc1059) + %tmp34 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1060) + %tmp34_43 = tt.broadcast %tmp34 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1060) + %tmp34_44 = arith.subi %m_17, %tmp34_43 : tensor<1x64xi32> loc(#loc1060) + %tmp35 = arith.constant 16 : i32 loc(#loc1061) + %tmp35_45 = arith.constant dense<16> : tensor<1xi32> loc(#loc1061) + %tmp36 = arith.constant 0 : i32 loc(#loc1062) + %tmp36_46 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1062) + %tmp36_47 = arith.cmpi slt, %tmp34_44, %tmp36_46 : tensor<1x64xi32> loc(#loc1062) + %tmp36_48 = arith.constant 0 : i32 loc(#loc1063) + %tmp36_49 = arith.constant dense<0> : tensor<1xi32> loc(#loc1063) + %tmp36_50 = arith.cmpi slt, %tmp35_45, %tmp36_49 : tensor<1xi32> loc(#loc1063) + %tmp36_51 = tt.expand_dims %tmp36_50 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc1064) + %tmp36_52 = tt.broadcast %tmp36_51 : tensor<1x1xi1> -> tensor<1x64xi1> loc(#loc1064) + %tmp36_53 = arith.cmpi ne, %tmp36_47, %tmp36_52 : tensor<1x64xi1> loc(#loc1064) + %tmp36_54 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1065) + %tmp36_55 = tt.broadcast %tmp36_54 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1065) + %tmp36_56 = arith.remsi %tmp34_44, %tmp36_55 : tensor<1x64xi32> loc(#loc1065) + %tmp36_57 = arith.constant 0 : i32 loc(#loc1066) + %tmp36_58 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1066) + %tmp36_59 = arith.cmpi ne, %tmp36_56, %tmp36_58 : tensor<1x64xi32> loc(#loc1066) + %tmp36_60 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1067) + %tmp36_61 = tt.broadcast %tmp36_60 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1067) + %tmp36_62 = arith.divsi %tmp34_44, %tmp36_61 : tensor<1x64xi32> loc(#loc1067) + %tmp36_63 = arith.constant 1 : i32 loc(#loc1068) + %tmp36_64 = arith.constant 1 : i32 loc(#loc1068) + %tmp36_65 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc1068) + %tmp36_66 = arith.subi %tmp36_62, %tmp36_65 : tensor<1x64xi32> loc(#loc1068) + %tmp36_67 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1069) + %tmp36_68 = tt.broadcast %tmp36_67 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1069) + %tmp36_69 = arith.divsi %tmp34_44, %tmp36_68 : tensor<1x64xi32> loc(#loc1069) + %tmp36_70 = arith.select %tmp36_59, %tmp36_66, %tmp36_69 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc1070) + %tmp36_71 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1071) + %tmp36_72 = tt.broadcast %tmp36_71 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1071) + %tmp36_73 = arith.divsi %tmp34_44, %tmp36_72 : tensor<1x64xi32> loc(#loc1071) + %tmp36_74 = arith.select %tmp36_53, %tmp36_70, %tmp36_73 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc1072) + %tmp37 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1073) + %tmp37_75 = tt.broadcast %tmp37 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1073) + %tmp37_76 = arith.subi %n_18, %tmp37_75 : tensor<128x1xi32> loc(#loc1073) + %tmp38 = arith.constant 0 : i32 loc(#loc1074) + %tmp38_77 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1074) + %tmp38_78 = arith.cmpi slt, %tmp37_76, %tmp38_77 : tensor<128x1xi32> loc(#loc1074) + %tmp38_79 = arith.constant 0 : i32 loc(#loc1075) + %tmp38_80 = arith.constant dense<0> : tensor<1xi32> loc(#loc1075) + %tmp38_81 = arith.cmpi slt, %tmp35_45, %tmp38_80 : tensor<1xi32> loc(#loc1075) + %tmp38_82 = tt.expand_dims %tmp38_81 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc1076) + %tmp38_83 = tt.broadcast %tmp38_82 : tensor<1x1xi1> -> tensor<128x1xi1> loc(#loc1076) + %tmp38_84 = arith.cmpi ne, %tmp38_78, %tmp38_83 : tensor<128x1xi1> loc(#loc1076) + %tmp38_85 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1077) + %tmp38_86 = tt.broadcast %tmp38_85 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1077) + %tmp38_87 = arith.remsi %tmp37_76, %tmp38_86 : tensor<128x1xi32> loc(#loc1077) + %tmp38_88 = arith.constant 0 : i32 loc(#loc1078) + %tmp38_89 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1078) + %tmp38_90 = arith.cmpi ne, %tmp38_87, %tmp38_89 : tensor<128x1xi32> loc(#loc1078) + %tmp38_91 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1079) + %tmp38_92 = tt.broadcast %tmp38_91 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1079) + %tmp38_93 = arith.divsi %tmp37_76, %tmp38_92 : tensor<128x1xi32> loc(#loc1079) + %tmp38_94 = arith.constant 1 : i32 loc(#loc1080) + %tmp38_95 = arith.constant 1 : i32 loc(#loc1080) + %tmp38_96 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc1080) + %tmp38_97 = arith.subi %tmp38_93, %tmp38_96 : tensor<128x1xi32> loc(#loc1080) + %tmp38_98 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1081) + %tmp38_99 = tt.broadcast %tmp38_98 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1081) + %tmp38_100 = arith.divsi %tmp37_76, %tmp38_99 : tensor<128x1xi32> loc(#loc1081) + %tmp38_101 = arith.select %tmp38_90, %tmp38_97, %tmp38_100 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc1082) + %tmp38_102 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1083) + %tmp38_103 = tt.broadcast %tmp38_102 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1083) + %tmp38_104 = arith.divsi %tmp37_76, %tmp38_103 : tensor<128x1xi32> loc(#loc1083) + %tmp38_105 = arith.select %tmp38_84, %tmp38_101, %tmp38_104 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc1084) + %tmp39 = tt.broadcast %tmp36_74 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc1085) + %tmp39_106 = tt.broadcast %tmp38_105 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc1085) + %tmp39_107 = arith.cmpi eq, %tmp39, %tmp39_106 : tensor<128x64xi32> loc(#loc1085) + %tmp40 = arith.andi %tmp33_42, %tmp39_107 : tensor<128x64xi1> loc(#loc1086) + %tmp41 = arith.ori %tmp31_37, %tmp40 : tensor<128x64xi1> loc(#loc1087) + %tmp42 = arith.ori %tmp28_31, %tmp41 : tensor<128x64xi1> loc(#loc1088) + %post_mod_scores_108 = arith.constant 0xFF800000 : f32 loc(#loc1089) + %post_mod_scores_109 = arith.constant 0xFF800000 : f32 loc(#loc1089) + %post_mod_scores_110 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1089) + %post_mod_scores_111 = arith.select %tmp42, %post_mod_scores_25, %post_mod_scores_110 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1089) + %post_mod_scores_112 = arith.constant 1.44269502 : f32 loc(#loc1090) + %post_mod_scores_113 = arith.constant 1.44269502 : f32 loc(#loc1090) + %post_mod_scores_114 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1090) + %post_mod_scores_115 = arith.mulf %post_mod_scores_111, %post_mod_scores_114 : tensor<128x64xf32> loc(#loc1090) + %pT = tt.expand_dims %lse_10 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1091) + %pT_116 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1092) + %pT_117 = arith.subf %post_mod_scores_115, %pT_116 : tensor<128x64xf32> loc(#loc1092) + %pT_118 = math.exp2 %pT_117 : tensor<128x64xf32> loc(#loc1093) + %do = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1094) + %dv_119 = arith.truncf %pT_118 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1095) + %dv_120 = arith.constant 0.000000e+00 : f32 loc(#loc1096) + %dv_121 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1096) + %dv_122 = tt.dot %dv_119, %do, %dv_121, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1096) + %dv_123 = arith.addf %dv, %dv_122 : tensor<128x128xf32> loc(#loc1097) + %Di = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1098) + %Di_124 = arith.cmpi slt, %offs_m1, %Di : tensor<64xi32> loc(#loc1098) + %Di_125 = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1099) + %Di_126 = tt.addptr %Di_125, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1099) + %Di_127 = tt.load %Di_126, %Di_124 : tensor<64x!tt.ptr> loc(#loc1100) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1101) + %dpT_128 = arith.constant 0.000000e+00 : f32 loc(#loc1102) + %dpT_129 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1102) + %dpT_130 = tt.dot %v, %dpT, %dpT_129, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1102) + %dsT = tt.expand_dims %Di_127 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1103) + %dsT_131 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1104) + %dsT_132 = arith.subf %dpT_130, %dsT_131 : tensor<128x64xf32> loc(#loc1104) + %dsT_133 = arith.mulf %pT_118, %dsT_132 : tensor<128x64xf32> loc(#loc1105) + %grad_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1106) + %grad_scores_134 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1107) + %grad_scores_135 = arith.cmpi slt, %grad_scores, %grad_scores_134 : tensor<1x64xi32> loc(#loc1107) + %grad_scores_136 = arith.constant 0.000000e+00 : f32 loc(#loc1108) + %grad_scores_137 = arith.constant 0.000000e+00 : f32 loc(#loc1108) + %grad_scores_138 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1108) + %grad_scores_139 = tt.broadcast %grad_scores_135 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1108) + %grad_scores_140 = arith.select %grad_scores_139, %dsT_133, %grad_scores_138 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1108) + %dsT_141 = arith.constant 0.000000e+00 : f32 loc(#loc1109) + %dsT_142 = arith.constant 0.000000e+00 : f32 loc(#loc1109) + %dsT_143 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1109) + %dsT_144 = arith.select %tmp42, %grad_scores_140, %dsT_143 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1109) + %dk_145 = arith.truncf %dsT_144 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1110) + %dk_146 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1111) + %dk_147 = arith.constant 0.000000e+00 : f32 loc(#loc1112) + %dk_148 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1112) + %dk_149 = tt.dot %dk_145, %dk_146, %dk_148, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1112) + %dk_150 = arith.addf %dk, %dk_149 : tensor<128x128xf32> loc(#loc1113) + tt.return %dk_150, %dv_123 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc491) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc492) + %1 = ub.poison : tensor<128x128xf32> loc(#loc492) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc492) + } loc(#loc412) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: tensor<64x128x!tt.ptr> loc("ptr"(#loc238)), %offs_m: tensor<64xi32> loc("offs_m"(#loc238)), %offs_n: tensor<128xi32> loc("offs_n"(#loc238)), %M_LEN: i32 loc("M_LEN"(#loc238))) -> tensor<64x128xbf16> attributes {noinline = false} { + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc245) + %1 = tt.splat %M_LEN : i32 -> tensor<64x1xi32> loc(#loc246) + %2 = arith.cmpi slt, %0, %1 : tensor<64x1xi32> loc(#loc246) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc247) + %3 = tt.broadcast %2 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc247) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x128xf32> loc(#loc247) + %4 = arith.truncf %cst_0 : tensor<64x128xf32> to tensor<64x128xbf16> loc(#loc247) + %5 = tt.load %ptr, %3, %4 : tensor<64x128x!tt.ptr> loc(#loc247) + tt.return %5 : tensor<64x128xbf16> loc(#loc248) + ^bb1: // no predecessors + %6 = ub.poison : tensor<64x128xbf16> loc(#loc249) + tt.return %6 : tensor<64x128xbf16> loc(#loc249) + } loc(#loc238) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(43,)cconstexpr_bf16__(44,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc382)), %arg_K: !tt.ptr loc("arg_K"(#loc382)), %arg_V: !tt.ptr loc("arg_V"(#loc382)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc382)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc382)), %arg_DO: !tt.ptr loc("arg_DO"(#loc382)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc382)), %arg_DV: !tt.ptr loc("arg_DV"(#loc382)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc382)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc382)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc382)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc382)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc382)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc382)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc382)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc382)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc382)), %ks0: i32 loc("ks0"(#loc382)), %ks1: i32 loc("ks1"(#loc382)), %ks2: i32 loc("ks2"(#loc382)), %ks3: i32 loc("ks3"(#loc382)), %ks4: i32 loc("ks4"(#loc382)), %ks5: i32 loc("ks5"(#loc382)), %ks6: i32 loc("ks6"(#loc382)), %ks7: i32 loc("ks7"(#loc382)), %Q: !tt.ptr loc("Q"(#loc382)), %DO: !tt.ptr loc("DO"(#loc382)), %DELTA: !tt.ptr loc("DELTA"(#loc382)), %LSE: !tt.ptr loc("LSE"(#loc382)), %dk: tensor<128x128xf32> loc("dk"(#loc382)), %dv: tensor<128x128xf32> loc("dv"(#loc382)), %k: tensor<128x128xbf16> loc("k"(#loc382)), %v: tensor<128x128xbf16> loc("v"(#loc382)), %off_z: i32 loc("off_z"(#loc382)), %off_hq: i32 loc("off_hq"(#loc382)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc382)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc382)), %stride_qm: i32 loc("stride_qm"(#loc382)), %stride_qd: i32 loc("stride_qd"(#loc382)), %stride_dom: i32 loc("stride_dom"(#loc382)), %stride_dod: i32 loc("stride_dod"(#loc382)), %q_indices: !tt.ptr loc("q_indices"(#loc382)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc382))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc964) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc965) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc966) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc967) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc967) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc968) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc968) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc969) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc970) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc970) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc971) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc971) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc971) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc972) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc973) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc973) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc974) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc974) + %do_ptrs_14 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc975) + %do_ptrs_15 = tt.splat %stride_dod : i32 -> tensor<1x128xi32> loc(#loc976) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc976) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc977) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc977) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc977) + %hi = arith.constant 2 : i32 loc(#loc978) + %hi_20 = arith.constant 2 : i32 loc(#loc978) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc978) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks0) : (i32) -> i32 loc(#loc979) + %hi_23 = arith.constant 1 : i32 loc(#loc980) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc980) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc981) + %c0_i32 = arith.constant 0 : i32 loc(#loc401) + %c1_i32 = arith.constant 1 : i32 loc(#loc401) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc401) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc401) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc401) + %3 = ub.poison : i32 loc(#loc401) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(47,)cconstexpr_bf16__(48,)cconstexpr_1_d_44269504__(49,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %ks2, %ks3, %ks4, %ks5, %ks6, %ks7, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %ks0, %ks1, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, i32, i32, i32, i32, i32, i32, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc402) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc983) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc984) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc985) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc985) + %do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc986) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc987) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc987) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc988) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc988) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc409) + } loc(#loc1125) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc410) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc411) + %5 = ub.poison : tensor<128x128xf32> loc(#loc411) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc411) + } loc(#loc382) + tt.func private @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_i32_i32_i32_i32_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(47,)cconstexpr_bf16__(48,)cconstexpr_1_d_44269504__(49,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc412)), %arg_K: !tt.ptr loc("arg_K"(#loc412)), %arg_V: !tt.ptr loc("arg_V"(#loc412)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc412)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc412)), %arg_DO: !tt.ptr loc("arg_DO"(#loc412)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc412)), %arg_DV: !tt.ptr loc("arg_DV"(#loc412)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc412)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc412)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc412)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc412)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc412)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc412)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc412)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc412)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc412)), %ks0: i32 loc("ks0"(#loc412)), %ks1: i32 loc("ks1"(#loc412)), %ks2: i32 loc("ks2"(#loc412)), %ks3: i32 loc("ks3"(#loc412)), %ks4: i32 loc("ks4"(#loc412)), %ks5: i32 loc("ks5"(#loc412)), %ks6: i32 loc("ks6"(#loc412)), %ks7: i32 loc("ks7"(#loc412)), %dk: tensor<128x128xf32> loc("dk"(#loc412)), %dv: tensor<128x128xf32> loc("dv"(#loc412)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc412)), %k: tensor<128x128xbf16> loc("k"(#loc412)), %v: tensor<128x128xbf16> loc("v"(#loc412)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc412)), %DELTA: !tt.ptr loc("DELTA"(#loc412)), %LSE: !tt.ptr loc("LSE"(#loc412)), %Q_LEN: i32 loc("Q_LEN"(#loc412)), %KV_LEN: i32 loc("KV_LEN"(#loc412)), %off_z: i32 loc("off_z"(#loc412)), %off_hq: i32 loc("off_hq"(#loc412)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc412)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc412)), %offs_k: tensor<128xi32> loc("offs_k"(#loc412)), %offs_v: tensor<128xi32> loc("offs_v"(#loc412)), %stride_qm: i32 loc("stride_qm"(#loc412)), %stride_qd: i32 loc("stride_qd"(#loc412)), %stride_dom: i32 loc("stride_dom"(#loc412)), %stride_dod: i32 loc("stride_dod"(#loc412)), %q_indices: !tt.ptr loc("q_indices"(#loc412)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc412))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc1036) + %lse = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1037) + %lse_0 = arith.cmpi slt, %offs_m1, %lse : tensor<64xi32> loc(#loc1037) + %lse_1 = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1038) + %lse_2 = tt.addptr %lse_1, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1038) + %lse_3 = tt.load %lse_2, %lse_0 : tensor<64x!tt.ptr> loc(#loc1039) + %lse_4 = arith.constant 0xFF800000 : f32 loc(#loc1040) + %lse_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1040) + %lse_6 = arith.cmpf oeq, %lse_3, %lse_5 : tensor<64xf32> loc(#loc1040) + %lse_7 = arith.constant 0.000000e+00 : f32 loc(#loc1041) + %lse_8 = arith.constant 0.000000e+00 : f32 loc(#loc1041) + %lse_9 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1041) + %lse_10 = arith.select %lse_6, %lse_9, %lse_3 : tensor<64xi1>, tensor<64xf32> loc(#loc1041) + %qkT = arith.constant 0.000000e+00 : f32 loc(#loc1042) + %qkT_11 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1042) + %qkT_12 = tt.dot %k, %qT, %qkT_11, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1042) + %qkT_13 = arith.constant 0.0883883461 : f32 loc(#loc1043) + %qkT_14 = arith.constant 0.0883883461 : f32 loc(#loc1043) + %qkT_15 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1043) + %qkT_16 = arith.mulf %qkT_12, %qkT_15 : tensor<128x64xf32> loc(#loc1043) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1044) + %m_17 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S1_64S_i32__(%m, %Q_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc1045) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc1046) + %n_18 = tt.call @torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.get_bounded_indices__i32S128_1S_i32__(%n, %KV_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc1047) + %post_mod_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1048) + %post_mod_scores_19 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1049) + %post_mod_scores_20 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_19 : tensor<1x64xi32> loc(#loc1049) + %post_mod_scores_21 = arith.constant 0xFF800000 : f32 loc(#loc1050) + %post_mod_scores_22 = arith.constant 0xFF800000 : f32 loc(#loc1050) + %post_mod_scores_23 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1050) + %post_mod_scores_24 = tt.broadcast %post_mod_scores_20 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1050) + %post_mod_scores_25 = arith.select %post_mod_scores_24, %qkT_16, %post_mod_scores_23 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1050) + %post_mod_scores_26 = arith.constant 1.44269502 : f32 loc(#loc1090) + %post_mod_scores_27 = arith.constant 1.44269502 : f32 loc(#loc1090) + %post_mod_scores_28 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1090) + %post_mod_scores_29 = arith.mulf %post_mod_scores_25, %post_mod_scores_28 : tensor<128x64xf32> loc(#loc1090) + %pT = tt.expand_dims %lse_10 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1091) + %pT_30 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1092) + %pT_31 = arith.subf %post_mod_scores_29, %pT_30 : tensor<128x64xf32> loc(#loc1092) + %pT_32 = math.exp2 %pT_31 : tensor<128x64xf32> loc(#loc1093) + %do = tt.call @"torch._inductor.runtime.compile_tasks.c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1094) + %dv_33 = arith.truncf %pT_32 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1095) + %dv_34 = arith.constant 0.000000e+00 : f32 loc(#loc1096) + %dv_35 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1096) + %dv_36 = tt.dot %dv_33, %do, %dv_35, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1096) + %dv_37 = arith.addf %dv, %dv_36 : tensor<128x128xf32> loc(#loc1097) + %Di = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1098) + %Di_38 = arith.cmpi slt, %offs_m1, %Di : tensor<64xi32> loc(#loc1098) + %Di_39 = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1099) + %Di_40 = tt.addptr %Di_39, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1099) + %Di_41 = tt.load %Di_40, %Di_38 : tensor<64x!tt.ptr> loc(#loc1100) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1101) + %dpT_42 = arith.constant 0.000000e+00 : f32 loc(#loc1102) + %dpT_43 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1102) + %dpT_44 = tt.dot %v, %dpT, %dpT_43, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1102) + %dsT = tt.expand_dims %Di_41 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1103) + %dsT_45 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1104) + %dsT_46 = arith.subf %dpT_44, %dsT_45 : tensor<128x64xf32> loc(#loc1104) + %dsT_47 = arith.mulf %pT_32, %dsT_46 : tensor<128x64xf32> loc(#loc1105) + %grad_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1106) + %grad_scores_48 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1107) + %grad_scores_49 = arith.cmpi slt, %grad_scores, %grad_scores_48 : tensor<1x64xi32> loc(#loc1107) + %grad_scores_50 = arith.constant 0.000000e+00 : f32 loc(#loc1108) + %grad_scores_51 = arith.constant 0.000000e+00 : f32 loc(#loc1108) + %grad_scores_52 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1108) + %grad_scores_53 = tt.broadcast %grad_scores_49 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1108) + %grad_scores_54 = arith.select %grad_scores_53, %dsT_47, %grad_scores_52 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1108) + %dk_55 = arith.truncf %grad_scores_54 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1110) + %dk_56 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1111) + %dk_57 = arith.constant 0.000000e+00 : f32 loc(#loc1112) + %dk_58 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1112) + %dk_59 = tt.dot %dk_55, %dk_56, %dk_58, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1112) + %dk_60 = arith.addf %dk, %dk_59 : tensor<128x128xf32> loc(#loc1113) + tt.return %dk_60, %dv_37 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc491) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc492) + %1 = ub.poison : tensor<128x128xf32> loc(#loc492) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc492) + } loc(#loc412) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":94:54) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":94:49) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":95:54) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":95:49) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":96:54) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":96:49) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:74) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:66) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:100) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:91) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:82) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:59) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:126) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:118) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:152) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:143) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:134) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:111) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:53) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":99:58) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":99:53) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":100:58) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":100:53) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":102:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":103:9) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":104:10) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":106:10) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":111:24) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":112:36) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":113:34) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":115:27) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":116:28) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":117:23) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":119:15) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":120:16) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":122:28) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:25) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:47) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:35) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:59) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":125:25) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":125:47) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":125:35) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":125:59) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:27) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:50) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:37) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:61) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":131:9) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":132:9) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":133:10) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":135:14) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":136:26) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":137:26) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:14) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:7) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":140:24) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":142:29) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":143:30) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:29) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:54) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:44) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":145:35) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":146:41) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":148:30) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":151:35) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":152:42) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":152:54) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":154:55) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":154:78) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":155:50) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":155:83) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":155:68) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:30) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:52) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:40) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:63) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:32) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:55) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:42) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:66) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":160:32) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":160:55) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":160:42) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":160:66) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:30) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:35) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:46) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:56) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":163:17) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":164:19) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":167:19) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":168:21) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":169:25) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":172:22) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":174:36) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":175:42) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":175:29) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":178:107) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":179:111) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:58) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:34) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:25) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:57) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:33) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:26) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:30) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:50) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":191:18) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":195:30) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:27) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:41) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:53) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:39) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:42) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:29) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":207:12) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":214:39) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:31) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:45) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:62) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:43) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":218:46) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":218:33) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":226:16) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:32) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:43) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:24) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:63) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:74) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:56) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":232:14) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:48) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:59) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:76) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:87) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:69) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:30) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":239:29) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":240:30) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":242:26) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":245:29) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":249:22) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":250:22) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":252:25) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":253:42) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":253:29) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":256:107) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":257:107) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":262:30) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:32) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:51) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:34) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:56) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:44) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:67) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:36) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:59) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:46) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:70) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":268:36) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":268:59) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":268:46) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":268:70) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:34) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:39) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:50) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:60) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":271:21) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":272:23) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":275:25) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":276:29) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":278:39) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":279:46) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":279:58) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":281:58) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":281:80) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":282:53) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":282:81) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":282:70) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":286:32) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:30) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:43) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:55) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:42) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:45) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:32) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":298:16) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":306:41) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:34) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:47) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:64) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:46) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":310:49) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":310:36) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":318:20) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":303:12) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:31) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:42) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:23) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:62) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:73) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:55) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":325:26) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":326:25) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":327:25) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:50) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:71) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:61) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:30) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":334:14) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":337:29) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:31) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:27) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:45) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:53) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:41) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:64) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:71) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":344:59) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:59) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:55) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:74) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:69) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:29) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:99) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:4) +#loc229 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:16) +#loc230 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc231 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc232 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:11) +#loc233 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:4) +#loc234 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc235 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc236 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc237 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc239 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:27) +#loc240 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:38) +#loc241 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:20) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:56) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:67) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:49) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:41) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:52) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:23) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:15) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":792:4) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":387:26) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":388:26) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:26) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:37) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:18) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:56) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:67) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:49) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:26) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:37) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:18) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:56) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:67) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:49) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:43) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:90) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:101) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:63) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":397:28) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":405:12) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":411:64) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:28) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:19) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":415:28) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":415:19) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":417:19) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":417:8) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":419:11) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":419:4) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":458:105) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":459:19) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":461:14) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":464:36) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":464:46) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":467:36) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":467:46) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":476:43) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":476:54) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":476:79) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":480:31) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":481:22) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":483:23) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":484:22) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":485:23) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":486:22) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":487:22) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":488:24) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":489:23) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":490:23) +#loc301 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":491:33) +#loc302 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:34) +#loc303 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:49) +#loc304 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:41) +#loc305 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:70) +#loc306 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:79) +#loc307 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:91) +#loc308 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:99) +#loc309 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:111) +#loc310 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:102) +#loc311 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:128) +#loc312 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:119) +#loc313 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":493:23) +#loc314 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:34) +#loc315 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:49) +#loc316 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:41) +#loc317 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:70) +#loc318 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:79) +#loc319 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:91) +#loc320 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:99) +#loc321 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:111) +#loc322 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:102) +#loc323 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:128) +#loc324 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:119) +#loc325 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":495:25) +#loc326 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":496:24) +#loc327 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":497:23) +#loc328 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":498:23) +#loc329 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":503:69) +#loc330 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":506:27) +#loc331 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:39) +#loc332 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:21) +#loc333 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":510:104) +#loc334 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":512:20) +#loc335 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:22) +#loc336 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:19) +#loc337 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:14) +#loc338 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":520:39) +#loc339 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":520:50) +#loc340 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":520:71) +#loc341 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":524:32) +#loc342 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":524:43) +#loc343 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":524:62) +#loc344 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":524:73) +#loc345 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":524:54) +#loc346 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":531:43) +#loc347 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":533:15) +#loc348 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:30) +#loc349 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:21) +#loc350 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:10) +#loc351 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":537:11) +#loc352 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":537:4) +#loc353 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:41) +#loc354 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:52) +#loc355 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:23) +#loc356 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:15) +#loc358 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":762:21) +#loc359 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":762:11) +#loc360 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":762:4) +#loc362 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":752:33) +#loc363 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:38) +#loc364 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:24) +#loc365 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:109) +#loc366 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:113) +#loc367 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:39) +#loc368 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:55) +#loc369 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:25) +#loc370 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:30) +#loc371 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:35) +#loc372 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:60) +#loc373 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:34) +#loc374 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:48) +#loc375 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:63) +#loc376 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:29) +#loc377 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:47) +#loc378 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:61) +#loc379 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:42) +#loc380 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":758:11) +#loc381 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":758:4) +#loc383 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":580:26) +#loc384 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":581:26) +#loc385 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:26) +#loc386 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:37) +#loc387 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:18) +#loc388 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:56) +#loc389 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:67) +#loc390 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:49) +#loc391 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:27) +#loc392 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:38) +#loc393 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:19) +#loc394 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:58) +#loc395 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:69) +#loc396 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:51) +#loc397 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:42) +#loc398 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:87) +#loc399 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:98) +#loc400 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:61) +#loc401 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":592:28) +#loc402 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":600:12) +#loc403 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":605:62) +#loc404 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:28) +#loc405 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:19) +#loc406 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:28) +#loc407 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:19) +#loc408 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":610:19) +#loc409 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":610:8) +#loc410 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":612:11) +#loc411 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":612:4) +#loc413 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":651:105) +#loc414 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:52) +#loc415 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:28) +#loc416 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:22) +#loc417 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:26) +#loc418 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:46) +#loc419 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":658:20) +#loc420 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":660:15) +#loc421 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":662:36) +#loc422 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":662:46) +#loc423 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":665:36) +#loc424 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":665:46) +#loc425 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":674:43) +#loc426 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":674:54) +#loc427 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":674:78) +#loc428 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":678:32) +#loc429 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":679:24) +#loc430 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":681:25) +#loc431 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":682:24) +#loc432 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":683:25) +#loc433 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":684:24) +#loc434 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":685:24) +#loc435 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":686:25) +#loc436 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":687:24) +#loc437 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":688:24) +#loc438 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":689:33) +#loc439 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:34) +#loc440 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:49) +#loc441 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:41) +#loc442 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:70) +#loc443 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:79) +#loc444 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:91) +#loc445 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:99) +#loc446 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:111) +#loc447 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:102) +#loc448 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:128) +#loc449 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:119) +#loc450 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":691:24) +#loc451 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:34) +#loc452 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:49) +#loc453 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:41) +#loc454 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:70) +#loc455 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:79) +#loc456 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:91) +#loc457 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:99) +#loc458 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:111) +#loc459 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:102) +#loc460 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:128) +#loc461 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:119) +#loc462 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":693:25) +#loc463 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":694:24) +#loc464 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":695:24) +#loc465 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":696:24) +#loc466 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":700:69) +#loc467 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":703:27) +#loc468 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:44) +#loc469 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:40) +#loc470 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:22) +#loc471 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":705:99) +#loc472 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:24) +#loc473 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:43) +#loc474 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:10) +#loc475 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:53) +#loc476 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:29) +#loc477 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:21) +#loc478 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:29) +#loc479 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:20) +#loc480 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:25) +#loc481 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:22) +#loc482 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:16) +#loc483 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":723:39) +#loc484 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":723:50) +#loc485 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":723:70) +#loc486 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":737:45) +#loc487 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:24) +#loc488 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:52) +#loc489 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:43) +#loc490 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:10) +#loc491 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":741:11) +#loc492 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":741:4) +#loc518 = loc("ZQ"(#loc24)) +#loc519 = loc("HQ"(#loc25)) +#loc520 = loc("HKV"(#loc26)) +#loc521 = loc("ZKV"(#loc27)) +#loc522 = loc("pid"(#loc28)) +#loc523 = loc("NUM_KV_BLOCKS"(#loc29)) +#loc524 = loc("NUM_Q_BLOCKS"(#loc30)) +#loc525 = loc("off_zq"(#loc31)) +#loc526 = loc("off_hkv"(#loc32)) +#loc527 = loc("off_zkv"(#loc33)) +#loc528 = loc("SPARSE_Z"(#loc34)) +#loc529 = loc("SPARSE_HQ"(#loc35)) +#loc530 = loc("sparse_idx_z"(#loc36)) +#loc531 = loc("k_adj"(#loc37)) +#loc532 = loc("k_adj"(#loc38)) +#loc533 = loc("k_adj"(#loc39)) +#loc534 = loc("k_adj"(#loc40)) +#loc535 = loc("v_adj"(#loc41)) +#loc536 = loc("v_adj"(#loc42)) +#loc537 = loc("v_adj"(#loc43)) +#loc538 = loc("v_adj"(#loc44)) +#loc539 = loc("dv_adj"(#loc45)) +#loc540 = loc("dv_adj"(#loc46)) +#loc541 = loc("dv_adj"(#loc47)) +#loc542 = loc("dv_adj"(#loc48)) +#loc543 = loc("K"(#loc49)) +#loc544 = loc("V"(#loc50)) +#loc545 = loc("DV"(#loc51)) +#loc546 = loc("RCP_LN2"(#loc52)) +#loc547 = loc("offs_k"(#loc53)) +#loc548 = loc("offs_v"(#loc54)) +#loc549 = loc("off_pid"(#loc57)) +#loc550 = loc("SPARSE_Q_MULTIPLE"(#loc58)) +#loc551 = loc("SPARSE_KV_MULTIPLE"(#loc59)) +#loc552 = loc("off_hq2"(#loc60)) +#loc553 = loc("off_hq2"(#loc61)) +#loc554 = loc("off_hq2"(#loc62)) +#loc555 = loc("start_m2_block"(#loc63)) +#loc556 = loc("off_pid_mask"(#loc64)) +#loc557 = loc("stride_kv_idx_h"(#loc65)) +#loc558 = loc("sparse_idx_hq2"(#loc66)) +#loc559 = loc("sparse_hz_offset"(#loc67)) +#loc560 = loc("sparse_hz_offset"(#loc68)) +#loc561 = loc("sparse_kv_num_blks_offset"(#loc69)) +#loc562 = loc("sparse_kv_num_blks_offset"(#loc70)) +#loc563 = loc("sparse_kv_idx_offset"(#loc71)) +#loc564 = loc("sparse_kv_idx_offset"(#loc72)) +#loc565 = loc("sparse_kv_idx_offset"(#loc73)) +#loc566 = loc("q_adj2"(#loc74)) +#loc567 = loc("q_adj2"(#loc75)) +#loc568 = loc("q_adj2"(#loc76)) +#loc569 = loc("q_adj2"(#loc77)) +#loc570 = loc("do_adj2"(#loc78)) +#loc571 = loc("do_adj2"(#loc79)) +#loc572 = loc("do_adj2"(#loc80)) +#loc573 = loc("do_adj2"(#loc81)) +#loc574 = loc("dq_adj2"(#loc82)) +#loc575 = loc("dq_adj2"(#loc83)) +#loc576 = loc("dq_adj2"(#loc84)) +#loc577 = loc("dq_adj2"(#loc85)) +#loc578 = loc("off_chz2"(#loc86)) +#loc579 = loc("off_chz2"(#loc87)) +#loc580 = loc("off_chz2"(#loc88)) +#loc581 = loc("off_chz2"(#loc89)) +#loc582 = loc("Q2"(#loc90)) +#loc583 = loc("DO2"(#loc91)) +#loc584 = loc("DQ2"(#loc92)) +#loc585 = loc("LSE2"(#loc93)) +#loc586 = loc("DELTA2"(#loc94)) +#loc587 = loc("dq"(#loc95)) +#loc588 = loc("start_m2"(#loc96)) +#loc589 = loc("offs_m2"(#loc97)) +#loc590 = loc("offs_m2"(#loc98)) +#loc591 = loc("q"(#loc99)) +#loc592 = loc("do"(#loc100)) +#loc593 = loc("Di"(#loc101)) +#loc594 = loc("Di"(#loc102)) +#loc595 = loc("Di"(#loc103)) +#loc596 = loc("lse"(#loc104)) +#loc597 = loc("lse"(#loc105)) +#loc598 = loc("lse"(#loc106)) +#loc599 = loc("lse"(#loc107)) +#loc600 = loc("lse"(#loc108)) +#loc601 = loc("lse"(#loc109)) +#loc602 = loc("kv_indices"(#loc110)) +#loc603 = loc("kv_start"(#loc111)) +#loc604 = loc("kv_start"(#loc112)) +#loc605 = loc("sparse_kv_num_blocks"(#loc113)) +#loc606 = loc("sparse_kv_num_blocks"(#loc114)) +#loc607 = loc("offs_n2"(#loc115)) +#loc608 = loc("offs_n2"(#loc116)) +#loc609 = loc("dq"(#loc117)) +#loc610 = loc("kv_indices"(#loc118)) +#loc611 = loc("kv_start"(#loc119)) +#loc612 = loc("kv_start"(#loc120)) +#loc613 = loc("sparse_kv_num_blocks"(#loc121)) +#loc614 = loc("sparse_kv_num_blocks"(#loc122)) +#loc615 = loc("offs_n2"(#loc123)) +#loc616 = loc("offs_n2"(#loc124)) +#loc617 = loc("dq"(#loc125)) +#loc618 = loc("dq_ptrs"(#loc126)) +#loc619 = loc("dq_ptrs"(#loc127)) +#loc620 = loc("dq_ptrs"(#loc128)) +#loc621 = loc("dq_ptrs"(#loc129)) +#loc622 = loc("dq_ptrs"(#loc130)) +#loc623 = loc("dq_ptrs"(#loc131)) +#loc624 = loc("dq"(#loc132)) +#loc625 = loc("SPARSE_Q_MULTIPLE"(#loc139)) +#loc626 = loc("SPARSE_KV_MULTIPLE"(#loc140)) +#loc627 = loc("pid_mask"(#loc141)) +#loc628 = loc("stride_q_idx_h"(#loc142)) +#loc629 = loc("dv"(#loc143)) +#loc630 = loc("dk"(#loc144)) +#loc631 = loc("start_n1"(#loc145)) +#loc632 = loc("offs_n1"(#loc146)) +#loc633 = loc("offs_n1"(#loc147)) +#loc634 = loc("k"(#loc148)) +#loc635 = loc("v"(#loc149)) +#loc636 = loc("dv"(#loc150)) +#loc637 = loc("off_hq1"(#loc151)) +#loc638 = loc("off_hq1"(#loc152)) +#loc639 = loc("q_adj1"(#loc153)) +#loc640 = loc("q_adj1"(#loc154)) +#loc641 = loc("q_adj1"(#loc155)) +#loc642 = loc("q_adj1"(#loc156)) +#loc643 = loc("do_adj1"(#loc157)) +#loc644 = loc("do_adj1"(#loc158)) +#loc645 = loc("do_adj1"(#loc159)) +#loc646 = loc("do_adj1"(#loc160)) +#loc647 = loc("dq_adj1"(#loc161)) +#loc648 = loc("dq_adj1"(#loc162)) +#loc649 = loc("dq_adj1"(#loc163)) +#loc650 = loc("dq_adj1"(#loc164)) +#loc651 = loc("off_chz1"(#loc165)) +#loc652 = loc("off_chz1"(#loc166)) +#loc653 = loc("off_chz1"(#loc167)) +#loc654 = loc("off_chz1"(#loc168)) +#loc655 = loc("Q1"(#loc169)) +#loc656 = loc("DO1"(#loc170)) +#loc657 = loc("LSE1"(#loc171)) +#loc658 = loc("DELTA1"(#loc172)) +#loc659 = loc("sparse_idx_hq1"(#loc173)) +#loc660 = loc("sparse_hz_offset"(#loc174)) +#loc661 = loc("sparse_hz_offset"(#loc175)) +#loc662 = loc("sparse_q_num_blks_offset"(#loc176)) +#loc663 = loc("sparse_q_num_blks_offset"(#loc177)) +#loc664 = loc("sparse_q_idx_offset"(#loc178)) +#loc665 = loc("sparse_q_idx_offset"(#loc179)) +#loc666 = loc("sparse_q_idx_offset"(#loc180)) +#loc667 = loc("q_indices"(#loc181)) +#loc668 = loc("q_start"(#loc182)) +#loc669 = loc("q_start"(#loc183)) +#loc670 = loc("sparse_q_num_blocks"(#loc184)) +#loc671 = loc("sparse_q_num_blocks"(#loc185)) +#loc672 = loc("offs_m1"(#loc186)) +#loc673 = loc("offs_m1"(#loc187)) +#loc674 = loc("q_indices"(#loc189)) +#loc675 = loc("q_start"(#loc190)) +#loc676 = loc("q_start"(#loc191)) +#loc677 = loc("sparse_q_num_blocks"(#loc192)) +#loc678 = loc("sparse_q_num_blocks"(#loc193)) +#loc679 = loc("offs_m1"(#loc194)) +#loc680 = loc("offs_m1"(#loc195)) +#loc681 = loc("dv_ptrs"(#loc198)) +#loc682 = loc("dv_ptrs"(#loc199)) +#loc683 = loc("dv_ptrs"(#loc200)) +#loc684 = loc("dv_ptrs"(#loc201)) +#loc685 = loc("dv_ptrs"(#loc202)) +#loc686 = loc("dv_ptrs"(#loc203)) +#loc687 = loc("index_n"(#loc204)) +#loc688 = loc("index_k"(#loc205)) +#loc689 = loc("index_v"(#loc206)) +#loc690 = loc("dk"(#loc211)) +#loc691 = loc("mask"(#loc212)) +#loc692 = loc("xindex"(#loc213)) +#loc693 = loc("xindex"(#loc214)) +#loc694 = loc("xindex"(#loc215)) +#loc695 = loc("xindex"(#loc216)) +#loc696 = loc("xindex"(#loc217)) +#loc697 = loc("xindex"(#loc218)) +#loc698 = loc("xindex"(#loc219)) +#loc699 = loc("xindex"(#loc220)) +#loc707 = loc("ptr"(#loc239)) +#loc708 = loc("ptr"(#loc240)) +#loc709 = loc("ptr"(#loc241)) +#loc710 = loc("ptr"(#loc242)) +#loc711 = loc("ptr"(#loc243)) +#loc712 = loc("ptr"(#loc244)) +#loc755 = loc("offs_k"(#loc251)) +#loc756 = loc("offs_v"(#loc252)) +#loc757 = loc("kT_ptrs"(#loc253)) +#loc758 = loc("kT_ptrs"(#loc254)) +#loc759 = loc("kT_ptrs"(#loc255)) +#loc760 = loc("kT_ptrs"(#loc256)) +#loc761 = loc("kT_ptrs"(#loc257)) +#loc762 = loc("kT_ptrs"(#loc258)) +#loc763 = loc("vT_ptrs"(#loc259)) +#loc764 = loc("vT_ptrs"(#loc260)) +#loc765 = loc("vT_ptrs"(#loc261)) +#loc766 = loc("vT_ptrs"(#loc262)) +#loc767 = loc("vT_ptrs"(#loc263)) +#loc768 = loc("vT_ptrs"(#loc264)) +#loc769 = loc("hi"(#loc265)) +#loc770 = loc("hi"(#loc266)) +#loc771 = loc("hi"(#loc267)) +#loc772 = loc("hi"(#loc268)) +#loc773 = loc("dq"(#loc269)) +#loc774 = loc("dq"(#loc270)) +#loc775 = loc("offset"(#loc271)) +#loc776 = loc("kT_ptrs"(#loc272)) +#loc777 = loc("kT_ptrs"(#loc273)) +#loc778 = loc("vT_ptrs"(#loc274)) +#loc779 = loc("vT_ptrs"(#loc275)) +#loc780 = loc("offs_n2"(#loc276)) +#loc827 = loc("kT"(#loc281)) +#loc828 = loc("qk"(#loc282)) +#loc829 = loc("qk"(#loc283)) +#loc830 = loc("n"(#loc284)) +#loc831 = loc("n"(#loc285)) +#loc832 = loc("m"(#loc286)) +#loc833 = loc("m"(#loc287)) +#loc834 = loc("post_mod_scores"(#loc288)) +#loc835 = loc("post_mod_scores"(#loc289)) +#loc836 = loc("post_mod_scores"(#loc290)) +#loc837 = loc("tmp2"(#loc291)) +#loc838 = loc("tmp3"(#loc292)) +#loc839 = loc("tmp5"(#loc293)) +#loc840 = loc("tmp6"(#loc294)) +#loc841 = loc("tmp7"(#loc295)) +#loc842 = loc("tmp8"(#loc296)) +#loc843 = loc("tmp9"(#loc297)) +#loc844 = loc("tmp10"(#loc298)) +#loc845 = loc("tmp11"(#loc299)) +#loc846 = loc("tmp12"(#loc300)) +#loc847 = loc("tmp13"(#loc301)) +#loc848 = loc("tmp14"(#loc302)) +#loc849 = loc("tmp14"(#loc303)) +#loc850 = loc("tmp14"(#loc304)) +#loc851 = loc("tmp14"(#loc305)) +#loc852 = loc("tmp14"(#loc306)) +#loc853 = loc("tmp14"(#loc307)) +#loc854 = loc("tmp14"(#loc308)) +#loc855 = loc("tmp14"(#loc309)) +#loc856 = loc("tmp14"(#loc310)) +#loc857 = loc("tmp14"(#loc311)) +#loc858 = loc("tmp14"(#loc312)) +#loc859 = loc("tmp15"(#loc313)) +#loc860 = loc("tmp16"(#loc314)) +#loc861 = loc("tmp16"(#loc315)) +#loc862 = loc("tmp16"(#loc316)) +#loc863 = loc("tmp16"(#loc317)) +#loc864 = loc("tmp16"(#loc318)) +#loc865 = loc("tmp16"(#loc319)) +#loc866 = loc("tmp16"(#loc320)) +#loc867 = loc("tmp16"(#loc321)) +#loc868 = loc("tmp16"(#loc322)) +#loc869 = loc("tmp16"(#loc323)) +#loc870 = loc("tmp16"(#loc324)) +#loc871 = loc("tmp17"(#loc325)) +#loc872 = loc("tmp18"(#loc326)) +#loc873 = loc("tmp19"(#loc327)) +#loc874 = loc("tmp20"(#loc328)) +#loc875 = loc("post_mod_scores"(#loc329)) +#loc876 = loc("post_mod_scores"(#loc330)) +#loc877 = loc("p"(#loc331)) +#loc878 = loc("p"(#loc332)) +#loc879 = loc("vT"(#loc333)) +#loc880 = loc("dp"(#loc334)) +#loc881 = loc("ds"(#loc335)) +#loc882 = loc("ds"(#loc336)) +#loc883 = loc("ds"(#loc337)) +#loc884 = loc("grad_scores"(#loc338)) +#loc885 = loc("grad_scores"(#loc339)) +#loc886 = loc("grad_scores"(#loc340)) +#loc887 = loc("scatter_mask"(#loc341)) +#loc888 = loc("scatter_mask"(#loc342)) +#loc889 = loc("scatter_mask"(#loc343)) +#loc890 = loc("scatter_mask"(#loc344)) +#loc891 = loc("scatter_mask"(#loc345)) +#loc892 = loc("ds"(#loc346)) +#loc893 = loc("ds"(#loc347)) +#loc894 = loc("dq"(#loc348)) +#loc895 = loc("dq"(#loc349)) +#loc896 = loc("dq"(#loc350)) +#loc903 = loc("cur_block_idx"(#loc362)) +#loc904 = loc("cur_block"(#loc363)) +#loc905 = loc("cur_block"(#loc364)) +#loc906 = loc("next_block"(#loc365)) +#loc907 = loc("next_block"(#loc366)) +#loc908 = loc("next_block"(#loc367)) +#loc909 = loc("next_block"(#loc368)) +#loc910 = loc("next_block"(#loc369)) +#loc911 = loc("needs_jump"(#loc370)) +#loc912 = loc("needs_jump"(#loc371)) +#loc913 = loc("needs_jump"(#loc372)) +#loc914 = loc("jump_to_block"(#loc373)) +#loc915 = loc("jump_to_block"(#loc374)) +#loc916 = loc("jump_to_block"(#loc375)) +#loc917 = loc("offset"(#loc376)) +#loc918 = loc("offset"(#loc377)) +#loc919 = loc("offset"(#loc378)) +#loc920 = loc("offset"(#loc379)) +#loc964 = loc("offs_k"(#loc383)) +#loc965 = loc("offs_v"(#loc384)) +#loc966 = loc("qT_ptrs"(#loc385)) +#loc967 = loc("qT_ptrs"(#loc386)) +#loc968 = loc("qT_ptrs"(#loc387)) +#loc969 = loc("qT_ptrs"(#loc388)) +#loc970 = loc("qT_ptrs"(#loc389)) +#loc971 = loc("qT_ptrs"(#loc390)) +#loc972 = loc("do_ptrs"(#loc391)) +#loc973 = loc("do_ptrs"(#loc392)) +#loc974 = loc("do_ptrs"(#loc393)) +#loc975 = loc("do_ptrs"(#loc394)) +#loc976 = loc("do_ptrs"(#loc395)) +#loc977 = loc("do_ptrs"(#loc396)) +#loc978 = loc("hi"(#loc397)) +#loc979 = loc("hi"(#loc398)) +#loc980 = loc("hi"(#loc399)) +#loc981 = loc("hi"(#loc400)) +#loc982 = loc("dk"(#loc401)) +#loc983 = loc("offset"(#loc403)) +#loc984 = loc("qT_ptrs"(#loc404)) +#loc985 = loc("qT_ptrs"(#loc405)) +#loc986 = loc("do_ptrs"(#loc406)) +#loc987 = loc("do_ptrs"(#loc407)) +#loc988 = loc("offs_m1"(#loc408)) +#loc1036 = loc("qT"(#loc413)) +#loc1037 = loc("lse"(#loc414)) +#loc1038 = loc("lse"(#loc415)) +#loc1039 = loc("lse"(#loc416)) +#loc1040 = loc("lse"(#loc417)) +#loc1041 = loc("lse"(#loc418)) +#loc1042 = loc("qkT"(#loc419)) +#loc1043 = loc("qkT"(#loc420)) +#loc1044 = loc("m"(#loc421)) +#loc1045 = loc("m"(#loc422)) +#loc1046 = loc("n"(#loc423)) +#loc1047 = loc("n"(#loc424)) +#loc1048 = loc("post_mod_scores"(#loc425)) +#loc1049 = loc("post_mod_scores"(#loc426)) +#loc1050 = loc("post_mod_scores"(#loc427)) +#loc1051 = loc("tmp24"(#loc428)) +#loc1052 = loc("tmp25"(#loc429)) +#loc1053 = loc("tmp27"(#loc430)) +#loc1054 = loc("tmp28"(#loc431)) +#loc1055 = loc("tmp29"(#loc432)) +#loc1056 = loc("tmp30"(#loc433)) +#loc1057 = loc("tmp31"(#loc434)) +#loc1058 = loc("tmp32"(#loc435)) +#loc1059 = loc("tmp33"(#loc436)) +#loc1060 = loc("tmp34"(#loc437)) +#loc1061 = loc("tmp35"(#loc438)) +#loc1062 = loc("tmp36"(#loc439)) +#loc1063 = loc("tmp36"(#loc440)) +#loc1064 = loc("tmp36"(#loc441)) +#loc1065 = loc("tmp36"(#loc442)) +#loc1066 = loc("tmp36"(#loc443)) +#loc1067 = loc("tmp36"(#loc444)) +#loc1068 = loc("tmp36"(#loc445)) +#loc1069 = loc("tmp36"(#loc446)) +#loc1070 = loc("tmp36"(#loc447)) +#loc1071 = loc("tmp36"(#loc448)) +#loc1072 = loc("tmp36"(#loc449)) +#loc1073 = loc("tmp37"(#loc450)) +#loc1074 = loc("tmp38"(#loc451)) +#loc1075 = loc("tmp38"(#loc452)) +#loc1076 = loc("tmp38"(#loc453)) +#loc1077 = loc("tmp38"(#loc454)) +#loc1078 = loc("tmp38"(#loc455)) +#loc1079 = loc("tmp38"(#loc456)) +#loc1080 = loc("tmp38"(#loc457)) +#loc1081 = loc("tmp38"(#loc458)) +#loc1082 = loc("tmp38"(#loc459)) +#loc1083 = loc("tmp38"(#loc460)) +#loc1084 = loc("tmp38"(#loc461)) +#loc1085 = loc("tmp39"(#loc462)) +#loc1086 = loc("tmp40"(#loc463)) +#loc1087 = loc("tmp41"(#loc464)) +#loc1088 = loc("tmp42"(#loc465)) +#loc1089 = loc("post_mod_scores"(#loc466)) +#loc1090 = loc("post_mod_scores"(#loc467)) +#loc1091 = loc("pT"(#loc468)) +#loc1092 = loc("pT"(#loc469)) +#loc1093 = loc("pT"(#loc470)) +#loc1094 = loc("do"(#loc471)) +#loc1095 = loc("dv"(#loc472)) +#loc1096 = loc("dv"(#loc473)) +#loc1097 = loc("dv"(#loc474)) +#loc1098 = loc("Di"(#loc475)) +#loc1099 = loc("Di"(#loc476)) +#loc1100 = loc("Di"(#loc477)) +#loc1101 = loc("dpT"(#loc478)) +#loc1102 = loc("dpT"(#loc479)) +#loc1103 = loc("dsT"(#loc480)) +#loc1104 = loc("dsT"(#loc481)) +#loc1105 = loc("dsT"(#loc482)) +#loc1106 = loc("grad_scores"(#loc483)) +#loc1107 = loc("grad_scores"(#loc484)) +#loc1108 = loc("grad_scores"(#loc485)) +#loc1109 = loc("dsT"(#loc486)) +#loc1110 = loc("dk"(#loc487)) +#loc1111 = loc("dk"(#loc488)) +#loc1112 = loc("dk"(#loc489)) +#loc1113 = loc("dk"(#loc490)) +#loc1114 = loc("SPARSE_Q_MULTIPLE"(#loc550)) +#loc1115 = loc("SPARSE_KV_MULTIPLE"(#loc551)) +#loc1116 = loc("SPARSE_Q_MULTIPLE"(#loc625)) +#loc1117 = loc("SPARSE_KV_MULTIPLE"(#loc626)) +#loc1118 = loc("dk"(#loc636)) +#loc1119 = loc("offs_n2"(#loc773)) +#loc1120 = loc("dv"(#loc982)) +#loc1121 = loc("kT_ptrs"(#loc1119)) +#loc1122 = loc("offs_m1"(#loc1120)) +#loc1123 = loc("vT_ptrs"(#loc1121)) +#loc1124 = loc("qT_ptrs"(#loc1122)) +#loc1125 = loc("do_ptrs"(#loc1124)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttgir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..a82c1b4f042103bfb7f485b7ff858e864a04db3f --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttgir @@ -0,0 +1,1944 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":18:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#loc303 = loc("arg_Q"(#loc)) +#loc304 = loc("arg_K"(#loc)) +#loc305 = loc("arg_V"(#loc)) +#loc306 = loc("arg_LSE"(#loc)) +#loc307 = loc("arg_DELTA"(#loc)) +#loc308 = loc("arg_DO"(#loc)) +#loc309 = loc("arg_DQ"(#loc)) +#loc310 = loc("arg_DV"(#loc)) +#loc311 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc312 = loc("arg_KV_IDX"(#loc)) +#loc313 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc314 = loc("arg_Q_IDX"(#loc)) +#loc315 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc316 = loc("arg_FULL_KV_IDX"(#loc)) +#loc317 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc318 = loc("arg_FULL_Q_IDX"(#loc)) +#loc319 = loc("out_ptr0"(#loc)) +#loc320 = loc("ks0"(#loc)) +#loc321 = loc("ks1"(#loc)) +#loc322 = loc("ks2"(#loc)) +#loc323 = loc("ks3"(#loc)) +#loc324 = loc("ks4"(#loc)) +#loc325 = loc("ks5"(#loc)) +#loc326 = loc("ks6"(#loc)) +#loc327 = loc("ks7"(#loc)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc)), %ks2: i32 loc("ks2"(#loc)), %ks3: i32 loc("ks3"(#loc)), %ks4: i32 loc("ks4"(#loc)), %ks5: i32 loc("ks5"(#loc)), %ks6: i32 loc("ks6"(#loc)), %ks7: i32 loc("ks7"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<1024> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<1x128xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<128x1xi32, #blocked> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xbf16, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16, #blocked1> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x128xbf16, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_5 = arith.constant dense<0.0883883461> : tensor<128x128xf32, #mma> loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_8 = arith.constant dense<0.0883883461> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_9 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_10 = arith.constant dense<1.44269502> : tensor<128x64xf32, #mma1> loc(#loc1) + %true = arith.constant true loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %cst_11 = arith.constant dense<65536> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_12 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc1) + %cst_13 = arith.constant dense<262144> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_14 = arith.constant dense<8192> : tensor<64x128xi32, #blocked> loc(#loc1) + %cst_15 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_16 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc1) + %cst_17 = arith.constant dense<128> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_18 = arith.constant dense<4096> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_19 = arith.constant dense<1024> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_20 = arith.constant dense<128> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_21 = arith.constant dense<16> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_22 = arith.constant dense<16> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_23 = arith.constant dense<0xFF800000> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_24 = arith.constant dense<0.000000e+00> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_25 = arith.constant dense<1> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_26 = arith.constant dense<1> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_27 = arith.constant dense<0> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_28 = arith.constant dense<0> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_29 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc1) + %cst_30 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc1) + %0 = arith.muli %ks0, %c4096_i32 : i32 loc(#loc2) + %1 = arith.cmpi sle, %ks0, %c1_i32 : i32 loc(#loc3) + %2 = arith.extui %1 : i1 to i32 loc(#loc4) + %3 = arith.cmpi sgt, %ks0, %c1_i32 : i32 loc(#loc5) + %4 = arith.extui %3 : i1 to i32 loc(#loc6) + %5 = arith.muli %ks0, %4 : i32 loc(#loc6) + %6 = arith.addi %2, %5 : i32 loc(#loc7) + %7 = arith.muli %6, %c4096_i32 : i32 loc(#loc8) + %8 = arith.muli %6, %c128_i32 : i32 loc(#loc9) + %9 = arith.muli %ks1, %c1024_i32 : i32 loc(#loc10) + %pid = tt.get_program_id x : i32 loc(#loc328) + %NUM_KV_BLOCKS = arith.addi %ks1, %c127_i32 : i32 loc(#loc596) + %NUM_KV_BLOCKS_31 = arith.divsi %NUM_KV_BLOCKS, %c128_i32 : i32 loc(#loc597) + %NUM_Q_BLOCKS = arith.addi %ks0, %c127_i32 : i32 loc(#loc598) + %NUM_Q_BLOCKS_32 = arith.divsi %NUM_Q_BLOCKS, %c128_i32 : i32 loc(#loc599) + %off_zq = tt.get_program_id y : i32 loc(#loc331) + %off_hkv = tt.get_program_id z : i32 loc(#loc332) + %k_adj = arith.muli %off_hkv, %c128_i32 : i32 loc(#loc333) + %k_adj_33 = arith.extsi %k_adj : i32 to i64 loc(#loc334) + %dv_adj = arith.muli %9, %off_zq : i32 loc(#loc335) + %dv_adj_34 = arith.addi %k_adj, %dv_adj : i32 loc(#loc336) + %dv_adj_35 = arith.extsi %dv_adj_34 : i32 to i64 loc(#loc337) + %K = tt.addptr %arg_K, %k_adj_33 : !tt.ptr, i64 loc(#loc338) + %V = tt.addptr %arg_V, %k_adj_33 : !tt.ptr, i64 loc(#loc339) + %DV = tt.addptr %arg_DV, %dv_adj_35 : !tt.ptr, i64 loc(#loc340) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc341) + %offs_k_36 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc341) + %10 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS_31 : i32 loc(#loc27) + scf.if %10 { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS_31 : i32 loc(#loc342) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS_32 : i32 loc(#loc343) + %off_hq2_37 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc344) + %off_hq2_38 = arith.addi %off_hq2, %off_hq2_37 : i32 loc(#loc345) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS_32 : i32 loc(#loc346) + %sparse_kv_idx_offset = arith.muli %start_m2_block, %ks4 : i32 loc(#loc347) + %q_adj2 = arith.muli %off_hq2_38, %c128_i32 : i32 loc(#loc348) + %q_adj2_39 = arith.muli %0, %off_zq : i32 loc(#loc349) + %q_adj2_40 = arith.addi %q_adj2, %q_adj2_39 : i32 loc(#loc350) + %q_adj2_41 = arith.extsi %q_adj2_40 : i32 to i64 loc(#loc351) + %do_adj2 = arith.muli %8, %off_hq2_38 : i32 loc(#loc352) + %do_adj2_42 = arith.muli %7, %off_zq : i32 loc(#loc353) + %do_adj2_43 = arith.addi %do_adj2, %do_adj2_42 : i32 loc(#loc354) + %do_adj2_44 = arith.extsi %do_adj2_43 : i32 to i64 loc(#loc355) + %off_chz2 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc356) + %off_chz2_45 = arith.addi %off_chz2, %off_hq2_38 : i32 loc(#loc357) + %off_chz2_46 = arith.muli %off_chz2_45, %ks0 : i32 loc(#loc358) + %off_chz2_47 = arith.extsi %off_chz2_46 : i32 to i64 loc(#loc359) + %Q2 = tt.addptr %arg_Q, %q_adj2_41 : !tt.ptr, i64 loc(#loc360) + %DO2 = tt.addptr %arg_DO, %do_adj2_44 : !tt.ptr, i64 loc(#loc361) + %DQ2 = tt.addptr %arg_DQ, %q_adj2_41 : !tt.ptr, i64 loc(#loc362) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_47 : !tt.ptr, i64 loc(#loc363) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_47 : !tt.ptr, i64 loc(#loc364) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc365) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc366) + %offs_m2_48 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc366) + %offs_m2_49 = arith.addi %offs_m2, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc366) + %offs_m2_50 = arith.addi %offs_m2_48, %offs_k_36 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc366) + %ptr = tt.expand_dims %offs_m2_49 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc600) + %ptr_51 = tt.expand_dims %offs_m2_50 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> loc(#loc600) + %ptr_52 = arith.muli %ptr, %cst_1 : tensor<128x1xi32, #blocked> loc(#loc601) + %ptr_53 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc602) + %ptr_54 = tt.addptr %ptr_53, %ptr_52 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc602) + %ptr_55 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc603) + %ptr_56 = tt.expand_dims %ptr_55 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc603) + %ptr_57 = tt.broadcast %ptr_54 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc604) + %ptr_58 = tt.broadcast %ptr_56 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc604) + %ptr_59 = tt.addptr %ptr_57, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc604) + %q = tt.splat %ks0 : i32 -> tensor<128x1xi32, #blocked> loc(#loc605) + %q_60 = tt.splat %ks0 : i32 -> tensor<128x1xi32, #mma1> loc(#loc605) + %q_61 = arith.cmpi slt, %ptr, %q : tensor<128x1xi32, #blocked> loc(#loc605) + %q_62 = tt.broadcast %q_61 : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc606) + %q_63 = tt.load %ptr_59, %q_62, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc606) + %q_64 = ttg.local_alloc %q_63 : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc606) + %ptr_65 = arith.muli %ptr, %cst_20 : tensor<128x1xi32, #blocked> loc(#loc607) + %ptr_66 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc608) + %ptr_67 = tt.addptr %ptr_66, %ptr_65 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc608) + %ptr_68 = tt.broadcast %ptr_67 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc609) + %ptr_69 = tt.addptr %ptr_68, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc609) + %do = tt.load %ptr_69, %q_62, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc610) + %do_70 = ttg.local_alloc %do : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc610) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc374) + %Di_71 = arith.cmpi slt, %offs_m2_50, %Di : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc374) + %Di_72 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc375) + %Di_73 = tt.addptr %Di_72, %offs_m2_50 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc375) + %Di_74 = tt.load %Di_73, %Di_71 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc376) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc377) + %lse_75 = tt.addptr %lse, %offs_m2_50 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc377) + %lse_76 = tt.load %lse_75, %Di_71 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc378) + %lse_77 = arith.cmpf oeq, %lse_76, %cst_30 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc379) + %lse_78 = arith.select %lse_77, %cst_29, %lse_76 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc380) + %lse_79 = tt.expand_dims %lse_78 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> loc(#loc381) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset : !tt.ptr, i32 loc(#loc382) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc383) + %kv_start_80 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc384) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc385) + %sparse_kv_num_blocks_81 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc386) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc387) + %offs_n2_82 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc387) + %offs_n2_83 = tt.splat %kv_start_80 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc388) + %offs_n2_84 = tt.splat %kv_start_80 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc388) + %offs_n2_85 = arith.addi %offs_n2_83, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc388) + %offs_n2_86 = arith.addi %offs_n2_84, %offs_n2_82 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc388) + %kT_ptrs = tt.expand_dims %offs_n2_86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc611) + %kT_ptrs_87 = arith.muli %kT_ptrs, %cst_19 : tensor<1x64xi32, #blocked1> loc(#loc612) + %kT_ptrs_88 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc613) + %kT_ptrs_89 = tt.addptr %kT_ptrs_88, %kT_ptrs_87 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc613) + %kT_ptrs_90 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc614) + %kT_ptrs_91 = tt.expand_dims %kT_ptrs_90 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc614) + %kT_ptrs_92 = tt.broadcast %kT_ptrs_89 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc615) + %kT_ptrs_93 = tt.broadcast %kT_ptrs_91 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc615) + %kT_ptrs_94 = tt.addptr %kT_ptrs_92, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc615) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc616) + %vT_ptrs_95 = tt.addptr %vT_ptrs, %kT_ptrs_87 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc616) + %vT_ptrs_96 = tt.broadcast %vT_ptrs_95 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc617) + %vT_ptrs_97 = tt.addptr %vT_ptrs_96, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc617) + %hi = arith.muli %sparse_kv_num_blocks_81, %c2_i32 : i32 loc(#loc618) + %hi_98 = arith.addi %ks1, %c63_i32 : i32 loc(#loc772) + %hi_99 = arith.divsi %hi_98, %c64_i32 : i32 loc(#loc773) + %hi_100 = arith.maxsi %hi_99, %c1_i32 : i32 loc(#loc620) + %hi_101 = arith.minsi %hi, %hi_100 : i32 loc(#loc621) + %kT = tt.splat %ks1 : i32 -> tensor<1x64xi32, #mma1> loc(#loc919) + %kT_102 = tt.splat %ks1 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc919) + %m = arith.remsi %ptr_51, %q_60 : tensor<128x1xi32, #mma1> loc(#loc920) + %tmp3 = arith.cmpi slt, %m, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc776) + %tmp5 = tt.broadcast %m : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc777) + %tmp6 = tt.broadcast %tmp3 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc778) + %tmp7 = arith.cmpi sge, %m, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc779) + %tmp9 = tt.broadcast %tmp7 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc780) + %tmp14 = arith.remsi %m, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc781) + %tmp14_103 = arith.cmpi ne, %tmp14, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc782) + %tmp14_104 = arith.divsi %m, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc783) + %tmp14_105 = arith.subi %tmp14_104, %cst_26 : tensor<128x1xi32, #mma1> loc(#loc784) + %tmp14_106 = arith.select %tmp14_103, %tmp14_105, %tmp14_104 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc785) + %tmp14_107 = arith.select %tmp3, %tmp14_106, %tmp14_104 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc786) + %tmp17 = tt.broadcast %tmp14_107 : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc787) + %p = tt.broadcast %lse_79 : tensor<128x1xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc788) + %ds = tt.expand_dims %Di_74 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> loc(#loc789) + %ds_108 = tt.broadcast %ds : tensor<128x1xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc790) + %kT_109 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc921) + %vT = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc922) + %vT_ptrs_110 = arith.cmpi sgt, %hi_101, %c0_i32 : i32 loc(#loc931) + %kT_111 = arith.cmpi slt, %kT_ptrs, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc919) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_113 = ttg.memdesc_index %kT_109[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %vT_ptrs_114 = tt.splat %vT_ptrs_110 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc931) + %vT_ptrs_115 = arith.andi %vT_ptrs_114, %kT_112 : tensor<128x64xi1, #blocked1> loc(#loc931) + %kT_116 = ttg.async_copy_global_to_local %kT_ptrs_94, %kT_113 mask %vT_ptrs_115 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %kT_117 = ttg.async_commit_group tokens %kT_116 loc(#loc921) + %vT_118 = ttg.memdesc_index %vT[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_119 = ttg.async_copy_global_to_local %vT_ptrs_97, %vT_118 mask %vT_ptrs_115 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_120 = ttg.async_commit_group tokens %vT_119 loc(#loc922) + %vT_ptrs_121 = arith.cmpi sgt, %hi_101, %c1_i32 : i32 loc(#loc931) + %kT_ptrs_122 = tt.addptr %kT_ptrs_94, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc624) + %vT_ptrs_123 = tt.addptr %vT_ptrs_97, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc625) + %offs_n2_124 = arith.addi %offs_n2_86, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc626) + %kT_125 = tt.expand_dims %offs_n2_124 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc924) + %kT_126 = arith.cmpi slt, %kT_125, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc919) + %kT_127 = tt.broadcast %kT_126 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_128 = ttg.memdesc_index %kT_109[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %vT_ptrs_129 = tt.splat %vT_ptrs_121 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc931) + %vT_ptrs_130 = arith.andi %vT_ptrs_129, %kT_127 : tensor<128x64xi1, #blocked1> loc(#loc931) + %kT_131 = ttg.async_copy_global_to_local %kT_ptrs_122, %kT_128 mask %vT_ptrs_130 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %kT_132 = ttg.async_commit_group tokens %kT_131 loc(#loc921) + %vT_133 = ttg.memdesc_index %vT[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_134 = ttg.async_copy_global_to_local %vT_ptrs_123, %vT_133 mask %vT_ptrs_130 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_135 = ttg.async_commit_group tokens %vT_134 loc(#loc922) + ttng.fence_async_shared {bCluster = false} loc(#loc793) + %vT_ptrs_136:12 = scf.for %vT_ptrs_192 = %c0_i32 to %hi_101 step %c1_i32 iter_args(%arg26 = %cst_6, %kT_ptrs_193 = %kT_ptrs_122, %offs_n2_194 = %offs_n2_124, %vT_ptrs_195 = %vT_ptrs_123, %offs_n2_196 = %offs_n2_85, %arg31 = %c1_i32, %arg32 = %c-1_i32, %kT_197 = %kT_117, %kT_198 = %kT_132, %vT_199 = %vT_120, %vT_200 = %vT_135, %arg37 = %c64_i32) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32) : i32 { + %vT_ptrs_201 = arith.subi %hi_101, %c2_i32 : i32 loc(#loc931) + %vT_ptrs_202 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_201 : i32 loc(#loc931) + %vT_ptrs_203 = arith.subi %hi_101, %c1_i32 : i32 loc(#loc931) + %vT_ptrs_204 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_203 : i32 loc(#loc931) + %vT_ptrs_205 = arith.addi %arg32, %c1_i32 : i32 loc(#loc931) + %vT_ptrs_206 = arith.cmpi sge, %vT_ptrs_205, %c3_i32 : i32 loc(#loc931) + %vT_ptrs_207 = arith.select %vT_ptrs_206, %c0_i32, %vT_ptrs_205 : i32 loc(#loc931) + %kT_208 = tt.expand_dims %offs_n2_196 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc924) + %kT_209 = arith.cmpi slt, %kT_208, %kT : tensor<1x64xi32, #mma1> loc(#loc919) + %kT_210 = tt.broadcast %kT_209 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc921) + %kT_211 = ttg.async_wait %kT_197, %vT_199 {num = 2 : i32} loc(#loc921) + %kT_212 = ttg.memdesc_index %kT_109[%vT_ptrs_207] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %dq_213 = ttg.memdesc_trans %kT_212 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc794) + %qk = ttng.warp_group_dot %q_64, %kT_212, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc793) + %qk_214:3 = ttng.warp_group_dot_wait %qk, %q_64, %kT_212 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc793) + %qk_215 = arith.mulf %qk_214#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc795) + %n = arith.remsi %kT_208, %kT : tensor<1x64xi32, #mma1> loc(#loc925) + %post_mod_scores = arith.select %kT_210, %qk_215, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc797) + %tmp5_216 = tt.broadcast %n : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc777) + %tmp5_217 = arith.cmpi sle, %tmp5_216, %tmp5 : tensor<128x64xi32, #mma1> loc(#loc777) + %tmp6_218 = arith.andi %tmp6, %tmp5_217 : tensor<128x64xi1, #mma1> loc(#loc778) + %tmp8 = arith.cmpi slt, %n, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc798) + %tmp9_219 = tt.broadcast %tmp8 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc780) + %tmp9_220 = arith.andi %tmp9, %tmp9_219 : tensor<128x64xi1, #mma1> loc(#loc780) + %tmp10 = arith.extui %tmp8 : tensor<1x64xi1, #mma1> to tensor<1x64xi32, #mma1> loc(#loc799) + %tmp10_221 = arith.cmpi eq, %tmp10, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc799) + %tmp11 = tt.broadcast %tmp10_221 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc800) + %tmp11_222 = arith.andi %tmp9, %tmp11 : tensor<128x64xi1, #mma1> loc(#loc800) + %tmp16 = arith.remsi %n, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc801) + %tmp16_223 = arith.cmpi ne, %tmp16, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc802) + %tmp16_224 = arith.divsi %n, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc803) + %tmp16_225 = arith.subi %tmp16_224, %cst_25 : tensor<1x64xi32, #mma1> loc(#loc804) + %tmp16_226 = arith.select %tmp16_223, %tmp16_225, %tmp16_224 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc805) + %tmp16_227 = arith.select %tmp8, %tmp16_226, %tmp16_224 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc806) + %tmp17_228 = tt.broadcast %tmp16_227 : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc787) + %tmp17_229 = arith.cmpi eq, %tmp17, %tmp17_228 : tensor<128x64xi32, #mma1> loc(#loc787) + %tmp18 = arith.andi %tmp11_222, %tmp17_229 : tensor<128x64xi1, #mma1> loc(#loc807) + %tmp19 = arith.ori %tmp9_220, %tmp18 : tensor<128x64xi1, #mma1> loc(#loc808) + %tmp20 = arith.ori %tmp6_218, %tmp19 : tensor<128x64xi1, #mma1> loc(#loc809) + %post_mod_scores_230 = arith.select %tmp20, %post_mod_scores, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc810) + %post_mod_scores_231 = arith.mulf %post_mod_scores_230, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc811) + %p_232 = arith.subf %post_mod_scores_231, %p : tensor<128x64xf32, #mma1> loc(#loc788) + %p_233 = math.exp2 %p_232 : tensor<128x64xf32, #mma1> loc(#loc812) + %vT_234 = ttg.memdesc_index %vT[%vT_ptrs_207] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %dp = ttng.warp_group_dot %do_70, %vT_234, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc813) + %dp_235:3 = ttng.warp_group_dot_wait %dp, %do_70, %vT_234 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc813) + %ds_236 = arith.subf %dp_235#0, %ds_108 : tensor<128x64xf32, #mma1> loc(#loc790) + %ds_237 = arith.mulf %p_233, %ds_236 : tensor<128x64xf32, #mma1> loc(#loc814) + %grad_scores = arith.select %kT_210, %ds_237, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc815) + %ds_238 = arith.select %tmp20, %grad_scores, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc816) + %ds_239 = arith.truncf %ds_238 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc817) + %ds_240 = ttg.convert_layout %ds_239 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc817) + %dq_241 = ttng.warp_group_dot %ds_240, %dq_213, %arg26 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc818) + %offs_n2_242 = tt.splat %arg37 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc626) + %offs_n2_243 = arith.addi %offs_n2_196, %offs_n2_242 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc626) + %vT_ptrs_244 = arith.addi %vT_ptrs_192, %c1_i32 : i32 loc(#loc931) + %cur_block_idx = arith.divsi %vT_ptrs_244, %c2_i32 : i32 loc(#loc819) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc820) + %cur_block_245 = tt.load %cur_block, %vT_ptrs_204 evictionPolicy = evict_last : !tt.ptr loc(#loc821) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc822) + %next_block_246 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_81 : i32 loc(#loc823) + %next_block_247 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc824) + %vT_ptrs_248 = arith.andi %vT_ptrs_204, %next_block_246 : i1 loc(#loc931) + %next_block_249 = tt.load %next_block_247, %vT_ptrs_248 evictionPolicy = evict_last : !tt.ptr loc(#loc825) + %needs_jump = arith.addi %vT_ptrs_192, %c2_i32 : i32 loc(#loc826) + %needs_jump_250 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc827) + %needs_jump_251 = arith.cmpi eq, %needs_jump_250, %c0_i32 : i32 loc(#loc828) + %jump_to_block = arith.subi %next_block_249, %cur_block_245 : i32 loc(#loc829) + %jump_to_block_252 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc830) + %jump_to_block_253 = arith.subi %jump_to_block_252, %c64_i32 : i32 loc(#loc831) + %offset = arith.extui %needs_jump_251 : i1 to i32 loc(#loc832) + %offset_254 = arith.muli %jump_to_block_253, %offset : i32 loc(#loc832) + %offset_255 = arith.subi %c1_i32, %offset : i32 loc(#loc833) + %offset_256 = arith.muli %offset_255, %c64_i32 : i32 loc(#loc834) + %offset_257 = arith.addi %offset_254, %offset_256 : i32 loc(#loc835) + %kT_ptrs_258 = arith.muli %offset_257, %c1024_i32 : i32 loc(#loc628) + %kT_ptrs_259 = tt.splat %kT_ptrs_258 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc624) + %kT_ptrs_260 = tt.addptr %kT_ptrs_193, %kT_ptrs_259 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc624) + %vT_ptrs_261 = tt.addptr %vT_ptrs_195, %kT_ptrs_259 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc625) + %offs_n2_262 = tt.splat %offset_257 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc626) + %offs_n2_263 = arith.addi %offs_n2_194, %offs_n2_262 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc626) + %vT_ptrs_264 = arith.addi %arg31, %c1_i32 : i32 loc(#loc931) + %vT_ptrs_265 = arith.cmpi sge, %vT_ptrs_264, %c3_i32 : i32 loc(#loc931) + %vT_ptrs_266 = arith.select %vT_ptrs_265, %c0_i32, %vT_ptrs_264 : i32 loc(#loc931) + %kT_267 = tt.expand_dims %offs_n2_263 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc924) + %kT_268 = arith.cmpi slt, %kT_267, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc919) + %kT_269 = tt.broadcast %kT_268 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_270 = ttg.memdesc_index %kT_109[%vT_ptrs_266] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %vT_ptrs_271 = tt.splat %vT_ptrs_202 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc931) + %vT_ptrs_272 = arith.andi %vT_ptrs_271, %kT_269 : tensor<128x64xi1, #blocked1> loc(#loc931) + %kT_273 = ttg.async_copy_global_to_local %kT_ptrs_260, %kT_270 mask %vT_ptrs_272 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc921) + %kT_274 = ttg.async_commit_group tokens %kT_273 loc(#loc921) + %vT_275 = ttg.memdesc_index %vT[%vT_ptrs_266] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_276 = ttg.async_copy_global_to_local %vT_ptrs_261, %vT_275 mask %vT_ptrs_272 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc922) + %vT_277 = ttg.async_commit_group tokens %vT_276 loc(#loc922) + scf.yield %dq_241, %kT_ptrs_260, %offs_n2_263, %vT_ptrs_261, %offs_n2_243, %vT_ptrs_266, %vT_ptrs_207, %kT_198, %kT_274, %vT_200, %vT_277, %offset_257 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32 loc(#loc931) + } loc(#loc931) + %vT_ptrs_137 = ttng.warp_group_dot_wait %vT_ptrs_136#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> loc(#loc931) + %vT_ptrs_138 = ttg.async_wait {num = 0 : i32} loc(#loc931) + ttg.local_dealloc %vT : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc931) + ttg.local_dealloc %kT_109 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc931) + %kv_indices_139 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset : !tt.ptr, i32 loc(#loc469) + %kv_start_140 = tt.load %kv_indices_139 : !tt.ptr loc(#loc470) + %kv_start_141 = arith.muli %kv_start_140, %c128_i32 : i32 loc(#loc471) + %sparse_kv_num_blocks_142 = tt.addptr %arg_FULL_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc472) + %sparse_kv_num_blocks_143 = tt.load %sparse_kv_num_blocks_142 : !tt.ptr loc(#loc473) + %offs_n2_144 = tt.splat %kv_start_141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc474) + %offs_n2_145 = tt.splat %kv_start_141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc474) + %offs_n2_146 = arith.addi %offs_n2_144, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc474) + %offs_n2_147 = arith.addi %offs_n2_145, %offs_n2_82 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc474) + %kT_ptrs_148 = tt.expand_dims %offs_n2_147 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc629) + %kT_ptrs_149 = arith.muli %kT_ptrs_148, %cst_19 : tensor<1x64xi32, #blocked1> loc(#loc630) + %kT_ptrs_150 = tt.addptr %kT_ptrs_88, %kT_ptrs_149 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc631) + %kT_ptrs_151 = tt.broadcast %kT_ptrs_150 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc632) + %kT_ptrs_152 = tt.addptr %kT_ptrs_151, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc632) + %vT_ptrs_153 = tt.addptr %vT_ptrs, %kT_ptrs_149 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc633) + %vT_ptrs_154 = tt.broadcast %vT_ptrs_153 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc634) + %vT_ptrs_155 = tt.addptr %vT_ptrs_154, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc634) + %hi_156 = arith.muli %sparse_kv_num_blocks_143, %c2_i32 : i32 loc(#loc635) + %hi_157 = arith.minsi %hi_156, %hi_100 : i32 loc(#loc636) + %kT_158 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc926) + %vT_159 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc927) + %vT_ptrs_160 = arith.cmpi sgt, %hi_157, %c0_i32 : i32 loc(#loc932) + %kT_161 = arith.cmpi slt, %kT_ptrs_148, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc928) + %kT_162 = tt.broadcast %kT_161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc926) + %kT_163 = ttg.memdesc_index %kT_158[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %vT_ptrs_164 = tt.splat %vT_ptrs_160 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc932) + %vT_ptrs_165 = arith.andi %vT_ptrs_164, %kT_162 : tensor<128x64xi1, #blocked1> loc(#loc932) + %kT_166 = ttg.async_copy_global_to_local %kT_ptrs_152, %kT_163 mask %vT_ptrs_165 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %kT_167 = ttg.async_commit_group tokens %kT_166 loc(#loc926) + %vT_168 = ttg.memdesc_index %vT_159[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_169 = ttg.async_copy_global_to_local %vT_ptrs_155, %vT_168 mask %vT_ptrs_165 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_170 = ttg.async_commit_group tokens %vT_169 loc(#loc927) + %vT_ptrs_171 = arith.cmpi sgt, %hi_157, %c1_i32 : i32 loc(#loc932) + %kT_ptrs_172 = tt.addptr %kT_ptrs_152, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc638) + %vT_ptrs_173 = tt.addptr %vT_ptrs_155, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc639) + %offs_n2_174 = arith.addi %offs_n2_147, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc640) + %kT_175 = tt.expand_dims %offs_n2_174 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc929) + %kT_176 = arith.cmpi slt, %kT_175, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc928) + %kT_177 = tt.broadcast %kT_176 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc926) + %kT_178 = ttg.memdesc_index %kT_158[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %vT_ptrs_179 = tt.splat %vT_ptrs_171 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc932) + %vT_ptrs_180 = arith.andi %vT_ptrs_179, %kT_177 : tensor<128x64xi1, #blocked1> loc(#loc932) + %kT_181 = ttg.async_copy_global_to_local %kT_ptrs_172, %kT_178 mask %vT_ptrs_180 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %kT_182 = ttg.async_commit_group tokens %kT_181 loc(#loc926) + %vT_183 = ttg.memdesc_index %vT_159[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_184 = ttg.async_copy_global_to_local %vT_ptrs_173, %vT_183 mask %vT_ptrs_180 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_185 = ttg.async_commit_group tokens %vT_184 loc(#loc927) + ttng.fence_async_shared {bCluster = false} loc(#loc838) + %vT_ptrs_186:12 = scf.for %vT_ptrs_192 = %c0_i32 to %hi_157 step %c1_i32 iter_args(%vT_ptrs_193 = %vT_ptrs_137, %kT_ptrs_194 = %kT_ptrs_172, %offs_n2_195 = %offs_n2_174, %vT_ptrs_196 = %vT_ptrs_173, %offs_n2_197 = %offs_n2_146, %arg31 = %c1_i32, %arg32 = %c-1_i32, %kT_198 = %kT_167, %kT_199 = %kT_182, %vT_200 = %vT_170, %vT_201 = %vT_185, %arg37 = %c64_i32) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32) : i32 { + %vT_ptrs_202 = arith.subi %hi_157, %c2_i32 : i32 loc(#loc932) + %vT_ptrs_203 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_202 : i32 loc(#loc932) + %vT_ptrs_204 = arith.subi %hi_157, %c1_i32 : i32 loc(#loc932) + %vT_ptrs_205 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_204 : i32 loc(#loc932) + %vT_ptrs_206 = arith.addi %arg32, %c1_i32 : i32 loc(#loc932) + %vT_ptrs_207 = arith.cmpi sge, %vT_ptrs_206, %c3_i32 : i32 loc(#loc932) + %vT_ptrs_208 = arith.select %vT_ptrs_207, %c0_i32, %vT_ptrs_206 : i32 loc(#loc932) + %kT_209 = tt.expand_dims %offs_n2_197 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc929) + %kT_210 = arith.cmpi slt, %kT_209, %kT : tensor<1x64xi32, #mma1> loc(#loc928) + %kT_211 = tt.broadcast %kT_210 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc926) + %kT_212 = ttg.async_wait %kT_198, %vT_200 {num = 2 : i32} loc(#loc926) + %kT_213 = ttg.memdesc_index %kT_158[%vT_ptrs_208] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %dq_214 = ttg.memdesc_trans %kT_213 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc839) + %qk = ttng.warp_group_dot %q_64, %kT_213, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc838) + %qk_215:3 = ttng.warp_group_dot_wait %qk, %q_64, %kT_213 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc838) + %qk_216 = arith.mulf %qk_215#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc840) + %post_mod_scores = arith.select %kT_211, %qk_216, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc841) + %post_mod_scores_217 = arith.mulf %post_mod_scores, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc842) + %p_218 = arith.subf %post_mod_scores_217, %p : tensor<128x64xf32, #mma1> loc(#loc843) + %p_219 = math.exp2 %p_218 : tensor<128x64xf32, #mma1> loc(#loc844) + %vT_220 = ttg.memdesc_index %vT_159[%vT_ptrs_208] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %dp = ttng.warp_group_dot %do_70, %vT_220, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc845) + %dp_221:3 = ttng.warp_group_dot_wait %dp, %do_70, %vT_220 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc845) + %ds_222 = arith.subf %dp_221#0, %ds_108 : tensor<128x64xf32, #mma1> loc(#loc846) + %ds_223 = arith.mulf %p_219, %ds_222 : tensor<128x64xf32, #mma1> loc(#loc847) + %grad_scores = arith.select %kT_211, %ds_223, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc848) + %ds_224 = arith.truncf %grad_scores : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc849) + %ds_225 = ttg.convert_layout %ds_224 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc849) + %dq_226 = ttng.warp_group_dot %ds_225, %dq_214, %vT_ptrs_193 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc850) + %offs_n2_227 = tt.splat %arg37 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc640) + %offs_n2_228 = arith.addi %offs_n2_197, %offs_n2_227 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc640) + %vT_ptrs_229 = arith.addi %vT_ptrs_192, %c1_i32 : i32 loc(#loc932) + %cur_block_idx = arith.divsi %vT_ptrs_229, %c2_i32 : i32 loc(#loc851) + %cur_block = tt.addptr %kv_indices_139, %cur_block_idx : !tt.ptr, i32 loc(#loc852) + %cur_block_230 = tt.load %cur_block, %vT_ptrs_205 evictionPolicy = evict_last : !tt.ptr loc(#loc853) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc854) + %next_block_231 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_143 : i32 loc(#loc855) + %next_block_232 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc856) + %vT_ptrs_233 = arith.andi %vT_ptrs_205, %next_block_231 : i1 loc(#loc932) + %next_block_234 = tt.load %next_block_232, %vT_ptrs_233 evictionPolicy = evict_last : !tt.ptr loc(#loc857) + %needs_jump = arith.addi %vT_ptrs_192, %c2_i32 : i32 loc(#loc858) + %needs_jump_235 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc859) + %needs_jump_236 = arith.cmpi eq, %needs_jump_235, %c0_i32 : i32 loc(#loc860) + %jump_to_block = arith.subi %next_block_234, %cur_block_230 : i32 loc(#loc861) + %jump_to_block_237 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc862) + %jump_to_block_238 = arith.subi %jump_to_block_237, %c64_i32 : i32 loc(#loc863) + %offset = arith.extui %needs_jump_236 : i1 to i32 loc(#loc864) + %offset_239 = arith.muli %jump_to_block_238, %offset : i32 loc(#loc864) + %offset_240 = arith.subi %c1_i32, %offset : i32 loc(#loc865) + %offset_241 = arith.muli %offset_240, %c64_i32 : i32 loc(#loc866) + %offset_242 = arith.addi %offset_239, %offset_241 : i32 loc(#loc867) + %kT_ptrs_243 = arith.muli %offset_242, %c1024_i32 : i32 loc(#loc642) + %kT_ptrs_244 = tt.splat %kT_ptrs_243 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc638) + %kT_ptrs_245 = tt.addptr %kT_ptrs_194, %kT_ptrs_244 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc638) + %vT_ptrs_246 = tt.addptr %vT_ptrs_196, %kT_ptrs_244 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc639) + %offs_n2_247 = tt.splat %offset_242 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc640) + %offs_n2_248 = arith.addi %offs_n2_195, %offs_n2_247 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc640) + %vT_ptrs_249 = arith.addi %arg31, %c1_i32 : i32 loc(#loc932) + %vT_ptrs_250 = arith.cmpi sge, %vT_ptrs_249, %c3_i32 : i32 loc(#loc932) + %vT_ptrs_251 = arith.select %vT_ptrs_250, %c0_i32, %vT_ptrs_249 : i32 loc(#loc932) + %kT_252 = tt.expand_dims %offs_n2_248 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc929) + %kT_253 = arith.cmpi slt, %kT_252, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc928) + %kT_254 = tt.broadcast %kT_253 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc926) + %kT_255 = ttg.memdesc_index %kT_158[%vT_ptrs_251] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %vT_ptrs_256 = tt.splat %vT_ptrs_203 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc932) + %vT_ptrs_257 = arith.andi %vT_ptrs_256, %kT_254 : tensor<128x64xi1, #blocked1> loc(#loc932) + %kT_258 = ttg.async_copy_global_to_local %kT_ptrs_245, %kT_255 mask %vT_ptrs_257 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc926) + %kT_259 = ttg.async_commit_group tokens %kT_258 loc(#loc926) + %vT_260 = ttg.memdesc_index %vT_159[%vT_ptrs_251] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_261 = ttg.async_copy_global_to_local %vT_ptrs_246, %vT_260 mask %vT_ptrs_257 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc927) + %vT_262 = ttg.async_commit_group tokens %vT_261 loc(#loc927) + scf.yield %dq_226, %kT_ptrs_245, %offs_n2_248, %vT_ptrs_246, %offs_n2_228, %vT_ptrs_251, %vT_ptrs_208, %kT_199, %kT_259, %vT_201, %vT_262, %offset_242 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32 loc(#loc932) + } loc(#loc932) + %vT_ptrs_187 = ttng.warp_group_dot_wait %vT_ptrs_186#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> loc(#loc932) + %vT_ptrs_188 = ttg.async_wait {num = 0 : i32} loc(#loc932) + ttg.local_dealloc %vT_159 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc932) + ttg.local_dealloc %kT_158 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc932) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc476) + %dq_ptrs_189 = tt.addptr %dq_ptrs, %ptr_52 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc476) + %dq_ptrs_190 = tt.broadcast %dq_ptrs_189 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc477) + %dq_ptrs_191 = tt.addptr %dq_ptrs_190, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc477) + %dq = arith.mulf %vT_ptrs_187, %cst_5 : tensor<128x128xf32, #mma> loc(#loc478) + %11 = arith.cmpi slt, %ptr_56, %cst_0 : tensor<1x128xi32, #blocked> loc(#loc172) + %12 = tt.broadcast %11 : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc173) + %13 = arith.andi %q_62, %12 : tensor<128x128xi1, #blocked> loc(#loc173) + %14 = arith.truncf %dq : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc174) + %15 = ttg.convert_layout %14 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc174) + tt.store %dq_ptrs_191, %15, %13 : tensor<128x128x!tt.ptr, #blocked> loc(#loc174) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc479) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc480) + %offs_n1_37 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc480) + %offs_n1_38 = arith.addi %offs_n1, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc480) + %offs_n1_39 = arith.addi %offs_n1_37, %offs_k_36 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc480) + %ptr = tt.expand_dims %offs_n1_38 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc643) + %ptr_40 = tt.expand_dims %offs_n1_39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> loc(#loc643) + %ptr_41 = arith.muli %ptr, %cst : tensor<128x1xi32, #blocked> loc(#loc644) + %ptr_42 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc645) + %ptr_43 = tt.addptr %ptr_42, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc645) + %ptr_44 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc646) + %ptr_45 = tt.expand_dims %ptr_44 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc646) + %ptr_46 = tt.broadcast %ptr_43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc647) + %ptr_47 = tt.broadcast %ptr_45 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc647) + %ptr_48 = tt.addptr %ptr_46, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc647) + %k = tt.splat %ks1 : i32 -> tensor<128x1xi32, #blocked> loc(#loc648) + %k_49 = tt.splat %ks1 : i32 -> tensor<128x1xi32, #mma1> loc(#loc648) + %k_50 = arith.cmpi slt, %ptr, %k : tensor<128x1xi32, #blocked> loc(#loc648) + %k_51 = tt.broadcast %k_50 : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc649) + %k_52 = tt.load %ptr_48, %k_51, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc649) + %k_53 = ttg.local_alloc %k_52 : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc649) + %ptr_54 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc650) + %ptr_55 = tt.addptr %ptr_54, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc650) + %ptr_56 = tt.broadcast %ptr_55 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc651) + %ptr_57 = tt.addptr %ptr_56, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc651) + %v = tt.load %ptr_57, %k_51, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc652) + %v_58 = ttg.local_alloc %v : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc652) + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc483) + %q_adj1 = arith.muli %0, %off_zq : i32 loc(#loc484) + %do_adj1 = arith.muli %7, %off_zq : i32 loc(#loc485) + %off_chz1 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc486) + %sparse_q_idx_offset = arith.muli %pid, %ks6 : i32 loc(#loc487) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset : !tt.ptr, i32 loc(#loc488) + %q_start = tt.load %q_indices, %true : !tt.ptr loc(#loc489) + %q_start_59 = arith.muli %q_start, %c128_i32 : i32 loc(#loc490) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc491) + %sparse_q_num_blocks_60 = tt.load %sparse_q_num_blocks, %true : !tt.ptr loc(#loc492) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc493) + %offs_m1_61 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc493) + %offs_m1_62 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc493) + %offs_m1_63 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc494) + %offs_m1_64 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc494) + %offs_m1_65 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc494) + %offs_m1_66 = arith.addi %offs_m1_63, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc494) + %offs_m1_67 = arith.addi %offs_m1_64, %offs_m1_61 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc494) + %offs_m1_68 = arith.addi %offs_m1_65, %offs_m1_62 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc494) + %qT_ptrs = tt.expand_dims %offs_m1_67 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc653) + %qT_ptrs_69 = arith.muli %qT_ptrs, %cst_18 : tensor<1x64xi32, #blocked1> loc(#loc654) + %qT_ptrs_70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc655) + %qT_ptrs_71 = tt.expand_dims %qT_ptrs_70 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc655) + %qT_ptrs_72 = tt.broadcast %qT_ptrs_71 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc656) + %do_ptrs = tt.expand_dims %offs_m1_68 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc657) + %do_ptrs_73 = arith.muli %do_ptrs, %cst_17 : tensor<64x1xi32, #blocked> loc(#loc658) + %do_ptrs_74 = tt.broadcast %ptr_45 : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> loc(#loc659) + %hi = arith.muli %sparse_q_num_blocks_60, %c2_i32 : i32 loc(#loc660) + %hi_75 = arith.addi %ks0, %c63_i32 : i32 loc(#loc868) + %hi_76 = arith.divsi %hi_75, %c64_i32 : i32 loc(#loc869) + %hi_77 = arith.maxsi %hi_76, %c1_i32 : i32 loc(#loc662) + %hi_78 = arith.minsi %hi, %hi_77 : i32 loc(#loc663) + %qT = tt.splat %ks0 : i32 -> tensor<1x64xi32, #mma1> loc(#loc870) + %qT_79 = tt.splat %ks0 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc870) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc665) + %n = arith.remsi %ptr_40, %k_49 : tensor<128x1xi32, #mma1> loc(#loc871) + %tmp27 = tt.broadcast %n : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc667) + %tmp30 = arith.cmpi slt, %n, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc668) + %tmp31 = tt.broadcast %tmp30 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc669) + %tmp32 = arith.extui %tmp30 : tensor<128x1xi1, #mma1> to tensor<128x1xi32, #mma1> loc(#loc670) + %tmp32_80 = arith.cmpi eq, %tmp32, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc670) + %tmp33 = tt.broadcast %tmp32_80 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc671) + %tmp38 = arith.remsi %n, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc672) + %tmp38_81 = arith.cmpi ne, %tmp38, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc673) + %tmp38_82 = arith.divsi %n, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc674) + %tmp38_83 = arith.subi %tmp38_82, %cst_26 : tensor<128x1xi32, #mma1> loc(#loc675) + %tmp38_84 = arith.select %tmp38_81, %tmp38_83, %tmp38_82 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc676) + %tmp38_85 = arith.select %tmp30, %tmp38_84, %tmp38_82 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc677) + %tmp39 = tt.broadcast %tmp38_85 : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc678) + %do = tt.splat %ks0 : i32 -> tensor<64x1xi32, #blocked> loc(#loc872) + %q_indices_86 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset : !tt.ptr, i32 loc(#loc523) + %q_start_87 = tt.load %q_indices_86, %true : !tt.ptr loc(#loc524) + %q_start_88 = arith.muli %q_start_87, %c128_i32 : i32 loc(#loc525) + %sparse_q_num_blocks_89 = tt.addptr %arg_FULL_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc526) + %sparse_q_num_blocks_90 = tt.load %sparse_q_num_blocks_89, %true : !tt.ptr loc(#loc527) + %offs_m1_91 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc528) + %offs_m1_92 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc528) + %offs_m1_93 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc528) + %offs_m1_94 = arith.addi %offs_m1_91, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc528) + %offs_m1_95 = arith.addi %offs_m1_92, %offs_m1_61 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc528) + %offs_m1_96 = arith.addi %offs_m1_93, %offs_m1_62 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc528) + %qT_ptrs_97 = tt.expand_dims %offs_m1_95 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc680) + %qT_ptrs_98 = arith.muli %qT_ptrs_97, %cst_18 : tensor<1x64xi32, #blocked1> loc(#loc681) + %do_ptrs_99 = tt.expand_dims %offs_m1_96 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc682) + %do_ptrs_100 = arith.muli %do_ptrs_99, %cst_17 : tensor<64x1xi32, #blocked> loc(#loc683) + %hi_101 = arith.muli %sparse_q_num_blocks_90, %c2_i32 : i32 loc(#loc684) + %hi_102 = arith.minsi %hi_101, %hi_77 : i32 loc(#loc685) + ttng.fence_async_shared {bCluster = false} loc(#loc686) + %dk:2 = scf.for %dk_107 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg26 = %cst_6, %arg27 = %cst_6) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>) : i32 { + %off_hq1_108 = arith.addi %off_hq1, %dk_107 : i32 loc(#loc531) + %q_adj1_109 = arith.muli %off_hq1_108, %c128_i32 : i32 loc(#loc532) + %q_adj1_110 = arith.addi %q_adj1_109, %q_adj1 : i32 loc(#loc533) + %q_adj1_111 = arith.extsi %q_adj1_110 : i32 to i64 loc(#loc534) + %do_adj1_112 = arith.muli %8, %off_hq1_108 : i32 loc(#loc535) + %do_adj1_113 = arith.addi %do_adj1_112, %do_adj1 : i32 loc(#loc536) + %do_adj1_114 = arith.extsi %do_adj1_113 : i32 to i64 loc(#loc537) + %off_chz1_115 = arith.addi %off_chz1, %off_hq1_108 : i32 loc(#loc538) + %off_chz1_116 = arith.muli %off_chz1_115, %ks0 : i32 loc(#loc539) + %off_chz1_117 = arith.extsi %off_chz1_116 : i32 to i64 loc(#loc540) + %Q1 = tt.addptr %arg_Q, %q_adj1_111 : !tt.ptr, i64 loc(#loc541) + %DO1 = tt.addptr %arg_DO, %do_adj1_114 : !tt.ptr, i64 loc(#loc542) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_117 : !tt.ptr, i64 loc(#loc543) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_117 : !tt.ptr, i64 loc(#loc544) + %qT_ptrs_118 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc688) + %qT_ptrs_119 = tt.addptr %qT_ptrs_118, %qT_ptrs_69 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc688) + %qT_ptrs_120 = tt.broadcast %qT_ptrs_119 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc656) + %qT_ptrs_121 = tt.addptr %qT_ptrs_120, %qT_ptrs_72 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc656) + %do_ptrs_122 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> loc(#loc689) + %do_ptrs_123 = tt.addptr %do_ptrs_122, %do_ptrs_73 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc689) + %do_ptrs_124 = tt.broadcast %do_ptrs_123 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc659) + %do_ptrs_125 = tt.addptr %do_ptrs_124, %do_ptrs_74 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc659) + %lse_126 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc690) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc691) + %qT_127 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc873) + %lse_128 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc692) + %do_129 = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc874) + %Di_130 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc693) + %do_ptrs_131 = arith.cmpi sgt, %hi_78, %c0_i32 : i32 loc(#loc934) + %qT_132 = arith.cmpi slt, %qT_ptrs, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc870) + %qT_133 = tt.broadcast %qT_132 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc873) + %qT_134 = ttg.memdesc_index %qT_127[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %do_ptrs_135 = tt.splat %do_ptrs_131 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc934) + %do_ptrs_136 = arith.andi %do_ptrs_135, %qT_133 : tensor<128x64xi1, #blocked1> loc(#loc934) + %qT_137 = ttg.async_copy_global_to_local %qT_ptrs_121, %qT_134 mask %do_ptrs_136 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %qT_138 = ttg.async_commit_group tokens %qT_137 loc(#loc873) + %lse_139 = arith.cmpi slt, %offs_m1_66, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc665) + %lse_140 = tt.addptr %lse_126, %offs_m1_66 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc690) + %lse_141 = ttg.memdesc_index %lse_128[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %do_ptrs_142 = tt.splat %do_ptrs_131 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %do_ptrs_143 = arith.andi %do_ptrs_142, %lse_139 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %lse_144 = ttg.async_copy_global_to_local %lse_140, %lse_141 mask %do_ptrs_143 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %lse_145 = ttg.async_commit_group tokens %lse_144 loc(#loc692) + %do_146 = arith.cmpi slt, %do_ptrs, %do : tensor<64x1xi32, #blocked> loc(#loc872) + %do_147 = tt.broadcast %do_146 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc874) + %do_148 = ttg.memdesc_index %do_129[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_ptrs_149 = tt.splat %do_ptrs_131 : i1 -> tensor<64x128xi1, #blocked> loc(#loc934) + %do_ptrs_150 = arith.andi %do_ptrs_149, %do_147 : tensor<64x128xi1, #blocked> loc(#loc934) + %do_151 = ttg.async_copy_global_to_local %do_ptrs_125, %do_148 mask %do_ptrs_150 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_152 = ttg.async_commit_group tokens %do_151 loc(#loc874) + %Di_153 = tt.addptr %Di, %offs_m1_66 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc691) + %Di_154 = ttg.memdesc_index %Di_130[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_155 = ttg.async_copy_global_to_local %Di_153, %Di_154 mask %do_ptrs_143 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_156 = ttg.async_commit_group tokens %Di_155 loc(#loc693) + %do_ptrs_157 = arith.cmpi sgt, %hi_78, %c1_i32 : i32 loc(#loc934) + %qT_ptrs_158 = tt.addptr %qT_ptrs_121, %cst_13 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc695) + %do_ptrs_159 = tt.addptr %do_ptrs_125, %cst_14 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc696) + %offs_m1_160 = arith.addi %offs_m1_66, %cst_15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc697) + %offs_m1_161 = arith.addi %offs_m1_67, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc697) + %offs_m1_162 = arith.addi %offs_m1_68, %cst_16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc697) + %qT_163 = tt.expand_dims %offs_m1_161 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc876) + %qT_164 = arith.cmpi slt, %qT_163, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc870) + %qT_165 = tt.broadcast %qT_164 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc873) + %qT_166 = ttg.memdesc_index %qT_127[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %do_ptrs_167 = tt.splat %do_ptrs_157 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc934) + %do_ptrs_168 = arith.andi %do_ptrs_167, %qT_165 : tensor<128x64xi1, #blocked1> loc(#loc934) + %qT_169 = ttg.async_copy_global_to_local %qT_ptrs_158, %qT_166 mask %do_ptrs_168 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %qT_170 = ttg.async_commit_group tokens %qT_169 loc(#loc873) + %lse_171 = arith.cmpi slt, %offs_m1_160, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc665) + %lse_172 = tt.addptr %lse_126, %offs_m1_160 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc690) + %lse_173 = ttg.memdesc_index %lse_128[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %do_ptrs_174 = tt.splat %do_ptrs_157 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %do_ptrs_175 = arith.andi %do_ptrs_174, %lse_171 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %lse_176 = ttg.async_copy_global_to_local %lse_172, %lse_173 mask %do_ptrs_175 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %lse_177 = ttg.async_commit_group tokens %lse_176 loc(#loc692) + %do_178 = tt.expand_dims %offs_m1_162 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc877) + %do_179 = arith.cmpi slt, %do_178, %do : tensor<64x1xi32, #blocked> loc(#loc872) + %do_180 = tt.broadcast %do_179 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc874) + %do_181 = ttg.memdesc_index %do_129[%c1_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_ptrs_182 = tt.splat %do_ptrs_157 : i1 -> tensor<64x128xi1, #blocked> loc(#loc934) + %do_ptrs_183 = arith.andi %do_ptrs_182, %do_180 : tensor<64x128xi1, #blocked> loc(#loc934) + %do_184 = ttg.async_copy_global_to_local %do_ptrs_159, %do_181 mask %do_ptrs_183 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_185 = ttg.async_commit_group tokens %do_184 loc(#loc874) + %Di_186 = tt.addptr %Di, %offs_m1_160 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc691) + %Di_187 = ttg.memdesc_index %Di_130[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_188 = ttg.async_copy_global_to_local %Di_186, %Di_187 mask %do_ptrs_175 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_189 = ttg.async_commit_group tokens %Di_188 loc(#loc693) + %do_ptrs_190:22 = scf.for %do_ptrs_265 = %c0_i32 to %hi_78 step %c1_i32 iter_args(%arg29 = %arg27, %arg30 = %arg26, %qT_ptrs_266 = %qT_ptrs_158, %offs_m1_267 = %offs_m1_161, %do_ptrs_268 = %do_ptrs_159, %offs_m1_269 = %offs_m1_162, %offs_m1_270 = %offs_m1_160, %arg36 = %c1_i32, %arg37 = %c-1_i32, %arg38 = %c1_i32, %arg39 = %c-1_i32, %offs_m1_271 = %offs_m1_66, %qT_272 = %qT_138, %qT_273 = %qT_170, %lse_274 = %lse_145, %lse_275 = %lse_177, %do_276 = %do_152, %do_277 = %do_185, %Di_278 = %Di_156, %Di_279 = %Di_189, %arg49 = %c64_i32, %offs_m1_280 = %offs_m1_66) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) : i32 { + %do_ptrs_281 = arith.subi %hi_78, %c2_i32 : i32 loc(#loc934) + %do_ptrs_282 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_281 : i32 loc(#loc934) + %do_ptrs_283 = arith.subi %hi_78, %c1_i32 : i32 loc(#loc934) + %do_ptrs_284 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_283 : i32 loc(#loc934) + %do_ptrs_285 = arith.addi %arg39, %c1_i32 : i32 loc(#loc934) + %do_ptrs_286 = arith.cmpi sge, %do_ptrs_285, %c2_i32 : i32 loc(#loc934) + %do_ptrs_287 = arith.select %do_ptrs_286, %c0_i32, %do_ptrs_285 : i32 loc(#loc934) + %do_ptrs_288 = arith.addi %arg37, %c1_i32 : i32 loc(#loc934) + %do_ptrs_289 = arith.cmpi sge, %do_ptrs_288, %c3_i32 : i32 loc(#loc934) + %do_ptrs_290 = arith.select %do_ptrs_289, %c0_i32, %do_ptrs_288 : i32 loc(#loc934) + %qT_291 = tt.expand_dims %offs_m1_271 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc876) + %qT_292 = tt.expand_dims %offs_m1_280 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc876) + %qT_293 = arith.cmpi slt, %qT_291, %qT : tensor<1x64xi32, #mma1> loc(#loc870) + %qT_294 = tt.broadcast %qT_293 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc873) + %qT_295 = ttg.async_wait %qT_272, %lse_274, %do_276, %Di_278 {num = 4 : i32} loc(#loc873) + %qT_296 = ttg.memdesc_index %qT_127[%do_ptrs_290] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %dk_297 = ttg.memdesc_trans %qT_296 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc698) + %lse_298 = ttg.memdesc_index %lse_128[%do_ptrs_287] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %lse_299 = ttg.local_load %lse_298 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc692) + %lse_300 = arith.cmpf oeq, %lse_299, %cst_23 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc699) + %lse_301 = arith.select %lse_300, %cst_24, %lse_299 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc700) + %qkT = ttng.warp_group_dot %k_53, %qT_296, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc686) + %qkT_302:3 = ttng.warp_group_dot_wait %qkT, %k_53, %qT_296 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc686) + %qkT_303 = arith.mulf %qkT_302#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc701) + %m = arith.remsi %qT_292, %qT : tensor<1x64xi32, #mma1> loc(#loc878) + %post_mod_scores = arith.select %qT_294, %qkT_303, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc703) + %tmp25 = arith.cmpi slt, %m, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc704) + %tmp27_304 = tt.broadcast %m : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc667) + %tmp27_305 = arith.cmpi sle, %tmp27, %tmp27_304 : tensor<128x64xi32, #mma1> loc(#loc667) + %tmp28 = tt.broadcast %tmp25 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc705) + %tmp28_306 = arith.andi %tmp28, %tmp27_305 : tensor<128x64xi1, #mma1> loc(#loc705) + %tmp29 = arith.cmpi sge, %m, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc706) + %tmp31_307 = tt.broadcast %tmp29 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc669) + %tmp31_308 = arith.andi %tmp31_307, %tmp31 : tensor<128x64xi1, #mma1> loc(#loc669) + %tmp33_309 = arith.andi %tmp31_307, %tmp33 : tensor<128x64xi1, #mma1> loc(#loc671) + %tmp36 = arith.remsi %m, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc707) + %tmp36_310 = arith.cmpi ne, %tmp36, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc708) + %tmp36_311 = arith.divsi %m, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc709) + %tmp36_312 = arith.subi %tmp36_311, %cst_25 : tensor<1x64xi32, #mma1> loc(#loc710) + %tmp36_313 = arith.select %tmp36_310, %tmp36_312, %tmp36_311 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc711) + %tmp36_314 = arith.select %tmp25, %tmp36_313, %tmp36_311 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc712) + %tmp39_315 = tt.broadcast %tmp36_314 : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc678) + %tmp39_316 = arith.cmpi eq, %tmp39_315, %tmp39 : tensor<128x64xi32, #mma1> loc(#loc678) + %tmp40 = arith.andi %tmp33_309, %tmp39_316 : tensor<128x64xi1, #mma1> loc(#loc713) + %tmp41 = arith.ori %tmp31_308, %tmp40 : tensor<128x64xi1, #mma1> loc(#loc714) + %tmp42 = arith.ori %tmp28_306, %tmp41 : tensor<128x64xi1, #mma1> loc(#loc715) + %post_mod_scores_317 = arith.select %tmp42, %post_mod_scores, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc716) + %post_mod_scores_318 = arith.mulf %post_mod_scores_317, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc717) + %pT = tt.expand_dims %lse_301 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc718) + %pT_319 = tt.broadcast %pT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc719) + %pT_320 = arith.subf %post_mod_scores_318, %pT_319 : tensor<128x64xf32, #mma1> loc(#loc719) + %pT_321 = math.exp2 %pT_320 : tensor<128x64xf32, #mma1> loc(#loc720) + %do_322 = ttg.memdesc_index %do_129[%do_ptrs_290] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %dpT = ttg.memdesc_trans %do_322 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc721) + %dv = arith.truncf %pT_321 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc722) + %dv_323 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc722) + %dv_324 = ttng.warp_group_dot %dv_323, %do_322, %arg30 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc723) + %Di_325 = ttg.memdesc_index %Di_130[%do_ptrs_287] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_326 = ttg.local_load %Di_325 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc693) + %dpT_327 = ttng.warp_group_dot %v_58, %dpT, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc724) + %dpT_328:3 = ttng.warp_group_dot_wait %dpT_327, %v_58, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc724) + %dsT = tt.expand_dims %Di_326 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc725) + %dsT_329 = tt.broadcast %dsT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc726) + %dsT_330 = arith.subf %dpT_328#0, %dsT_329 : tensor<128x64xf32, #mma1> loc(#loc726) + %dsT_331 = arith.mulf %pT_321, %dsT_330 : tensor<128x64xf32, #mma1> loc(#loc727) + %grad_scores = arith.select %qT_294, %dsT_331, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc728) + %dsT_332 = arith.select %tmp42, %grad_scores, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc729) + %dk_333 = arith.truncf %dsT_332 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc730) + %dk_334 = ttg.convert_layout %dk_333 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc730) + %dk_335 = ttng.warp_group_dot %dk_334, %dk_297, %arg29 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc731) + %offs_m1_336 = tt.splat %arg49 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc697) + %offs_m1_337 = arith.addi %offs_m1_280, %offs_m1_336 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc697) + %do_ptrs_338 = arith.addi %do_ptrs_265, %c1_i32 : i32 loc(#loc934) + %cur_block_idx = arith.divsi %do_ptrs_338, %c2_i32 : i32 loc(#loc879) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc880) + %cur_block_339 = tt.load %cur_block, %do_ptrs_284 evictionPolicy = evict_last : !tt.ptr loc(#loc881) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc882) + %next_block_340 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_60 : i32 loc(#loc883) + %next_block_341 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc884) + %do_ptrs_342 = arith.andi %do_ptrs_284, %next_block_340 : i1 loc(#loc934) + %next_block_343 = tt.load %next_block_341, %do_ptrs_342 evictionPolicy = evict_last : !tt.ptr loc(#loc885) + %needs_jump = arith.addi %do_ptrs_265, %c2_i32 : i32 loc(#loc886) + %needs_jump_344 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc887) + %needs_jump_345 = arith.cmpi eq, %needs_jump_344, %c0_i32 : i32 loc(#loc888) + %jump_to_block = arith.subi %next_block_343, %cur_block_339 : i32 loc(#loc889) + %jump_to_block_346 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc890) + %jump_to_block_347 = arith.subi %jump_to_block_346, %c64_i32 : i32 loc(#loc891) + %offset = arith.extui %needs_jump_345 : i1 to i32 loc(#loc892) + %offset_348 = arith.muli %jump_to_block_347, %offset : i32 loc(#loc892) + %offset_349 = arith.subi %c1_i32, %offset : i32 loc(#loc893) + %offset_350 = arith.muli %offset_349, %c64_i32 : i32 loc(#loc894) + %offset_351 = arith.addi %offset_348, %offset_350 : i32 loc(#loc895) + %qT_ptrs_352 = arith.muli %offset_351, %c4096_i32 : i32 loc(#loc733) + %qT_ptrs_353 = tt.splat %qT_ptrs_352 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc695) + %qT_ptrs_354 = tt.addptr %qT_ptrs_266, %qT_ptrs_353 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc695) + %do_ptrs_355 = arith.muli %offset_351, %c128_i32 : i32 loc(#loc734) + %do_ptrs_356 = tt.splat %do_ptrs_355 : i32 -> tensor<64x128xi32, #blocked> loc(#loc696) + %do_ptrs_357 = tt.addptr %do_ptrs_268, %do_ptrs_356 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc696) + %offs_m1_358 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc697) + %offs_m1_359 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc697) + %offs_m1_360 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc697) + %offs_m1_361 = arith.addi %offs_m1_270, %offs_m1_358 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc697) + %offs_m1_362 = arith.addi %offs_m1_267, %offs_m1_359 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc697) + %offs_m1_363 = arith.addi %offs_m1_269, %offs_m1_360 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc697) + %do_ptrs_364 = arith.addi %arg38, %c1_i32 : i32 loc(#loc934) + %do_ptrs_365 = arith.cmpi sge, %do_ptrs_364, %c2_i32 : i32 loc(#loc934) + %do_ptrs_366 = arith.select %do_ptrs_365, %c0_i32, %do_ptrs_364 : i32 loc(#loc934) + %do_ptrs_367 = arith.addi %arg36, %c1_i32 : i32 loc(#loc934) + %do_ptrs_368 = arith.cmpi sge, %do_ptrs_367, %c3_i32 : i32 loc(#loc934) + %do_ptrs_369 = arith.select %do_ptrs_368, %c0_i32, %do_ptrs_367 : i32 loc(#loc934) + %qT_370 = tt.expand_dims %offs_m1_362 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc876) + %qT_371 = arith.cmpi slt, %qT_370, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc870) + %qT_372 = tt.broadcast %qT_371 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc873) + %qT_373 = ttg.memdesc_index %qT_127[%do_ptrs_369] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %do_ptrs_374 = tt.splat %do_ptrs_282 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc934) + %do_ptrs_375 = arith.andi %do_ptrs_374, %qT_372 : tensor<128x64xi1, #blocked1> loc(#loc934) + %qT_376 = ttg.async_copy_global_to_local %qT_ptrs_354, %qT_373 mask %do_ptrs_375 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc873) + %qT_377 = ttg.async_commit_group tokens %qT_376 loc(#loc873) + %lse_378 = arith.cmpi slt, %offs_m1_361, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc665) + %lse_379 = tt.addptr %lse_126, %offs_m1_361 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc690) + %lse_380 = ttg.memdesc_index %lse_128[%do_ptrs_366] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %do_ptrs_381 = tt.splat %do_ptrs_282 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %do_ptrs_382 = arith.andi %do_ptrs_381, %lse_378 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + %lse_383 = ttg.async_copy_global_to_local %lse_379, %lse_380 mask %do_ptrs_382 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc692) + %lse_384 = ttg.async_commit_group tokens %lse_383 loc(#loc692) + %do_385 = tt.expand_dims %offs_m1_363 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc877) + %do_386 = arith.cmpi slt, %do_385, %do : tensor<64x1xi32, #blocked> loc(#loc872) + %do_387 = tt.broadcast %do_386 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc874) + %do_388 = ttg.memdesc_index %do_129[%do_ptrs_369] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_ptrs_389 = tt.splat %do_ptrs_282 : i1 -> tensor<64x128xi1, #blocked> loc(#loc934) + %do_ptrs_390 = arith.andi %do_ptrs_389, %do_387 : tensor<64x128xi1, #blocked> loc(#loc934) + %do_391 = ttg.async_copy_global_to_local %do_ptrs_357, %do_388 mask %do_ptrs_390 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc874) + %do_392 = ttg.async_commit_group tokens %do_391 loc(#loc874) + %Di_393 = tt.addptr %Di, %offs_m1_361 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc691) + %Di_394 = ttg.memdesc_index %Di_130[%do_ptrs_366] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_395 = ttg.async_copy_global_to_local %Di_393, %Di_394 mask %do_ptrs_382 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc693) + %Di_396 = ttg.async_commit_group tokens %Di_395 loc(#loc693) + scf.yield %dk_335, %dv_324, %qT_ptrs_354, %offs_m1_362, %do_ptrs_357, %offs_m1_363, %offs_m1_361, %do_ptrs_369, %do_ptrs_290, %do_ptrs_366, %do_ptrs_287, %offs_m1_270, %qT_273, %qT_377, %lse_275, %lse_384, %do_277, %do_392, %Di_279, %Di_396, %offset_351, %offs_m1_337 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc934) + } loc(#loc934) + %do_ptrs_191:2 = ttng.warp_group_dot_wait %do_ptrs_190#1, %do_ptrs_190#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc934) + %do_ptrs_192 = ttg.async_wait {num = 0 : i32} loc(#loc934) + ttg.local_dealloc %Di_130 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc934) + ttg.local_dealloc %do_129 : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc934) + ttg.local_dealloc %lse_128 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc934) + ttg.local_dealloc %qT_127 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc934) + %qT_ptrs_193 = tt.addptr %qT_ptrs_118, %qT_ptrs_98 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc735) + %qT_ptrs_194 = tt.broadcast %qT_ptrs_193 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc736) + %qT_ptrs_195 = tt.addptr %qT_ptrs_194, %qT_ptrs_72 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc736) + %do_ptrs_196 = tt.addptr %do_ptrs_122, %do_ptrs_100 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc737) + %do_ptrs_197 = tt.broadcast %do_ptrs_196 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc738) + %do_ptrs_198 = tt.addptr %do_ptrs_197, %do_ptrs_74 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc738) + %qT_199 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc896) + %lse_200 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc740) + %do_201 = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc897) + %Di_202 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc742) + %do_ptrs_203 = arith.cmpi sgt, %hi_102, %c0_i32 : i32 loc(#loc935) + %qT_204 = arith.cmpi slt, %qT_ptrs_97, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc898) + %qT_205 = tt.broadcast %qT_204 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc896) + %qT_206 = ttg.memdesc_index %qT_199[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %do_ptrs_207 = tt.splat %do_ptrs_203 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc935) + %do_ptrs_208 = arith.andi %do_ptrs_207, %qT_205 : tensor<128x64xi1, #blocked1> loc(#loc935) + %qT_209 = ttg.async_copy_global_to_local %qT_ptrs_195, %qT_206 mask %do_ptrs_208 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %qT_210 = ttg.async_commit_group tokens %qT_209 loc(#loc896) + %lse_211 = arith.cmpi slt, %offs_m1_94, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc743) + %lse_212 = tt.addptr %lse_126, %offs_m1_94 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc744) + %lse_213 = ttg.memdesc_index %lse_200[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %do_ptrs_214 = tt.splat %do_ptrs_203 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %do_ptrs_215 = arith.andi %do_ptrs_214, %lse_211 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %lse_216 = ttg.async_copy_global_to_local %lse_212, %lse_213 mask %do_ptrs_215 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %lse_217 = ttg.async_commit_group tokens %lse_216 loc(#loc740) + %do_218 = arith.cmpi slt, %do_ptrs_99, %do : tensor<64x1xi32, #blocked> loc(#loc899) + %do_219 = tt.broadcast %do_218 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc897) + %do_220 = ttg.memdesc_index %do_201[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_ptrs_221 = tt.splat %do_ptrs_203 : i1 -> tensor<64x128xi1, #blocked> loc(#loc935) + %do_ptrs_222 = arith.andi %do_ptrs_221, %do_219 : tensor<64x128xi1, #blocked> loc(#loc935) + %do_223 = ttg.async_copy_global_to_local %do_ptrs_198, %do_220 mask %do_ptrs_222 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_224 = ttg.async_commit_group tokens %do_223 loc(#loc897) + %Di_225 = tt.addptr %Di, %offs_m1_94 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc745) + %Di_226 = ttg.memdesc_index %Di_202[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_227 = ttg.async_copy_global_to_local %Di_225, %Di_226 mask %do_ptrs_215 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_228 = ttg.async_commit_group tokens %Di_227 loc(#loc742) + %do_ptrs_229 = arith.cmpi sgt, %hi_102, %c1_i32 : i32 loc(#loc935) + %qT_ptrs_230 = tt.addptr %qT_ptrs_195, %cst_13 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc746) + %do_ptrs_231 = tt.addptr %do_ptrs_198, %cst_14 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc747) + %offs_m1_232 = arith.addi %offs_m1_94, %cst_15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc748) + %offs_m1_233 = arith.addi %offs_m1_95, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc748) + %offs_m1_234 = arith.addi %offs_m1_96, %cst_16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc748) + %qT_235 = tt.expand_dims %offs_m1_233 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc900) + %qT_236 = arith.cmpi slt, %qT_235, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc898) + %qT_237 = tt.broadcast %qT_236 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc896) + %qT_238 = ttg.memdesc_index %qT_199[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %do_ptrs_239 = tt.splat %do_ptrs_229 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc935) + %do_ptrs_240 = arith.andi %do_ptrs_239, %qT_237 : tensor<128x64xi1, #blocked1> loc(#loc935) + %qT_241 = ttg.async_copy_global_to_local %qT_ptrs_230, %qT_238 mask %do_ptrs_240 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %qT_242 = ttg.async_commit_group tokens %qT_241 loc(#loc896) + %lse_243 = arith.cmpi slt, %offs_m1_232, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc743) + %lse_244 = tt.addptr %lse_126, %offs_m1_232 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc744) + %lse_245 = ttg.memdesc_index %lse_200[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %do_ptrs_246 = tt.splat %do_ptrs_229 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %do_ptrs_247 = arith.andi %do_ptrs_246, %lse_243 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %lse_248 = ttg.async_copy_global_to_local %lse_244, %lse_245 mask %do_ptrs_247 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %lse_249 = ttg.async_commit_group tokens %lse_248 loc(#loc740) + %do_250 = tt.expand_dims %offs_m1_234 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc901) + %do_251 = arith.cmpi slt, %do_250, %do : tensor<64x1xi32, #blocked> loc(#loc899) + %do_252 = tt.broadcast %do_251 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc897) + %do_253 = ttg.memdesc_index %do_201[%c1_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_ptrs_254 = tt.splat %do_ptrs_229 : i1 -> tensor<64x128xi1, #blocked> loc(#loc935) + %do_ptrs_255 = arith.andi %do_ptrs_254, %do_252 : tensor<64x128xi1, #blocked> loc(#loc935) + %do_256 = ttg.async_copy_global_to_local %do_ptrs_231, %do_253 mask %do_ptrs_255 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_257 = ttg.async_commit_group tokens %do_256 loc(#loc897) + %Di_258 = tt.addptr %Di, %offs_m1_232 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc745) + %Di_259 = ttg.memdesc_index %Di_202[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_260 = ttg.async_copy_global_to_local %Di_258, %Di_259 mask %do_ptrs_247 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_261 = ttg.async_commit_group tokens %Di_260 loc(#loc742) + %do_ptrs_262:20 = scf.for %do_ptrs_265 = %c0_i32 to %hi_102 step %c1_i32 iter_args(%do_ptrs_266 = %do_ptrs_191#1, %do_ptrs_267 = %do_ptrs_191#0, %qT_ptrs_268 = %qT_ptrs_230, %offs_m1_269 = %offs_m1_233, %do_ptrs_270 = %do_ptrs_231, %offs_m1_271 = %offs_m1_234, %offs_m1_272 = %offs_m1_232, %arg36 = %c1_i32, %arg37 = %c-1_i32, %arg38 = %c1_i32, %arg39 = %c-1_i32, %offs_m1_273 = %offs_m1_94, %qT_274 = %qT_210, %qT_275 = %qT_242, %lse_276 = %lse_217, %lse_277 = %lse_249, %do_278 = %do_224, %do_279 = %do_257, %Di_280 = %Di_228, %Di_281 = %Di_261) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { + %do_ptrs_282 = arith.subi %hi_102, %c2_i32 : i32 loc(#loc935) + %do_ptrs_283 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_282 : i32 loc(#loc935) + %do_ptrs_284 = arith.subi %hi_102, %c1_i32 : i32 loc(#loc935) + %do_ptrs_285 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_284 : i32 loc(#loc935) + %do_ptrs_286 = arith.addi %arg39, %c1_i32 : i32 loc(#loc935) + %do_ptrs_287 = arith.cmpi sge, %do_ptrs_286, %c2_i32 : i32 loc(#loc935) + %do_ptrs_288 = arith.select %do_ptrs_287, %c0_i32, %do_ptrs_286 : i32 loc(#loc935) + %do_ptrs_289 = arith.addi %arg37, %c1_i32 : i32 loc(#loc935) + %do_ptrs_290 = arith.cmpi sge, %do_ptrs_289, %c3_i32 : i32 loc(#loc935) + %do_ptrs_291 = arith.select %do_ptrs_290, %c0_i32, %do_ptrs_289 : i32 loc(#loc935) + %qT_292 = tt.expand_dims %offs_m1_273 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc900) + %qT_293 = arith.cmpi slt, %qT_292, %qT : tensor<1x64xi32, #mma1> loc(#loc898) + %qT_294 = tt.broadcast %qT_293 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc896) + %qT_295 = ttg.async_wait %qT_274, %lse_276, %do_278, %Di_280 {num = 4 : i32} loc(#loc896) + %qT_296 = ttg.memdesc_index %qT_199[%do_ptrs_291] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %dk_297 = ttg.memdesc_trans %qT_296 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc749) + %lse_298 = ttg.memdesc_index %lse_200[%do_ptrs_288] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %lse_299 = ttg.local_load %lse_298 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc740) + %lse_300 = arith.cmpf oeq, %lse_299, %cst_23 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc750) + %lse_301 = arith.select %lse_300, %cst_24, %lse_299 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc751) + %qkT = ttng.warp_group_dot %k_53, %qT_296, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc752) + %qkT_302:3 = ttng.warp_group_dot_wait %qkT, %k_53, %qT_296 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc752) + %qkT_303 = arith.mulf %qkT_302#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc753) + %post_mod_scores = arith.select %qT_294, %qkT_303, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc754) + %post_mod_scores_304 = arith.mulf %post_mod_scores, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc755) + %pT = tt.expand_dims %lse_301 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc756) + %pT_305 = tt.broadcast %pT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc757) + %pT_306 = arith.subf %post_mod_scores_304, %pT_305 : tensor<128x64xf32, #mma1> loc(#loc757) + %pT_307 = math.exp2 %pT_306 : tensor<128x64xf32, #mma1> loc(#loc758) + %do_308 = ttg.memdesc_index %do_201[%do_ptrs_291] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %dpT = ttg.memdesc_trans %do_308 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc759) + %dv = arith.truncf %pT_307 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc760) + %dv_309 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc760) + %dv_310 = ttng.warp_group_dot %dv_309, %do_308, %do_ptrs_267 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc761) + %Di_311 = ttg.memdesc_index %Di_202[%do_ptrs_288] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_312 = ttg.local_load %Di_311 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc742) + %dpT_313 = ttng.warp_group_dot %v_58, %dpT, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc762) + %dpT_314:3 = ttng.warp_group_dot_wait %dpT_313, %v_58, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc762) + %dsT = tt.expand_dims %Di_312 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc763) + %dsT_315 = tt.broadcast %dsT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc764) + %dsT_316 = arith.subf %dpT_314#0, %dsT_315 : tensor<128x64xf32, #mma1> loc(#loc764) + %dsT_317 = arith.mulf %pT_307, %dsT_316 : tensor<128x64xf32, #mma1> loc(#loc765) + %grad_scores = arith.select %qT_294, %dsT_317, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc766) + %dk_318 = arith.truncf %grad_scores : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc767) + %dk_319 = ttg.convert_layout %dk_318 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc767) + %dk_320 = ttng.warp_group_dot %dk_319, %dk_297, %do_ptrs_266 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc768) + %do_ptrs_321 = arith.addi %do_ptrs_265, %c1_i32 : i32 loc(#loc935) + %cur_block_idx = arith.divsi %do_ptrs_321, %c2_i32 : i32 loc(#loc902) + %cur_block = tt.addptr %q_indices_86, %cur_block_idx : !tt.ptr, i32 loc(#loc903) + %cur_block_322 = tt.load %cur_block, %do_ptrs_285 evictionPolicy = evict_last : !tt.ptr loc(#loc904) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc905) + %next_block_323 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_90 : i32 loc(#loc906) + %next_block_324 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc907) + %do_ptrs_325 = arith.andi %do_ptrs_285, %next_block_323 : i1 loc(#loc935) + %next_block_326 = tt.load %next_block_324, %do_ptrs_325 evictionPolicy = evict_last : !tt.ptr loc(#loc908) + %needs_jump = arith.addi %do_ptrs_265, %c2_i32 : i32 loc(#loc909) + %needs_jump_327 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc910) + %needs_jump_328 = arith.cmpi eq, %needs_jump_327, %c0_i32 : i32 loc(#loc911) + %jump_to_block = arith.subi %next_block_326, %cur_block_322 : i32 loc(#loc912) + %jump_to_block_329 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc913) + %jump_to_block_330 = arith.subi %jump_to_block_329, %c64_i32 : i32 loc(#loc914) + %offset = arith.extui %needs_jump_328 : i1 to i32 loc(#loc915) + %offset_331 = arith.muli %jump_to_block_330, %offset : i32 loc(#loc915) + %offset_332 = arith.subi %c1_i32, %offset : i32 loc(#loc916) + %offset_333 = arith.muli %offset_332, %c64_i32 : i32 loc(#loc917) + %offset_334 = arith.addi %offset_331, %offset_333 : i32 loc(#loc918) + %qT_ptrs_335 = arith.muli %offset_334, %c4096_i32 : i32 loc(#loc770) + %qT_ptrs_336 = tt.splat %qT_ptrs_335 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc746) + %qT_ptrs_337 = tt.addptr %qT_ptrs_268, %qT_ptrs_336 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc746) + %do_ptrs_338 = arith.muli %offset_334, %c128_i32 : i32 loc(#loc771) + %do_ptrs_339 = tt.splat %do_ptrs_338 : i32 -> tensor<64x128xi32, #blocked> loc(#loc747) + %do_ptrs_340 = tt.addptr %do_ptrs_270, %do_ptrs_339 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc747) + %offs_m1_341 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc748) + %offs_m1_342 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc748) + %offs_m1_343 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc748) + %offs_m1_344 = arith.addi %offs_m1_272, %offs_m1_341 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc748) + %offs_m1_345 = arith.addi %offs_m1_269, %offs_m1_342 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc748) + %offs_m1_346 = arith.addi %offs_m1_271, %offs_m1_343 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc748) + %do_ptrs_347 = arith.addi %arg38, %c1_i32 : i32 loc(#loc935) + %do_ptrs_348 = arith.cmpi sge, %do_ptrs_347, %c2_i32 : i32 loc(#loc935) + %do_ptrs_349 = arith.select %do_ptrs_348, %c0_i32, %do_ptrs_347 : i32 loc(#loc935) + %do_ptrs_350 = arith.addi %arg36, %c1_i32 : i32 loc(#loc935) + %do_ptrs_351 = arith.cmpi sge, %do_ptrs_350, %c3_i32 : i32 loc(#loc935) + %do_ptrs_352 = arith.select %do_ptrs_351, %c0_i32, %do_ptrs_350 : i32 loc(#loc935) + %qT_353 = tt.expand_dims %offs_m1_345 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc900) + %qT_354 = arith.cmpi slt, %qT_353, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc898) + %qT_355 = tt.broadcast %qT_354 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc896) + %qT_356 = ttg.memdesc_index %qT_199[%do_ptrs_352] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %do_ptrs_357 = tt.splat %do_ptrs_283 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc935) + %do_ptrs_358 = arith.andi %do_ptrs_357, %qT_355 : tensor<128x64xi1, #blocked1> loc(#loc935) + %qT_359 = ttg.async_copy_global_to_local %qT_ptrs_337, %qT_356 mask %do_ptrs_358 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc896) + %qT_360 = ttg.async_commit_group tokens %qT_359 loc(#loc896) + %lse_361 = arith.cmpi slt, %offs_m1_344, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc743) + %lse_362 = tt.addptr %lse_126, %offs_m1_344 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc744) + %lse_363 = ttg.memdesc_index %lse_200[%do_ptrs_349] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %do_ptrs_364 = tt.splat %do_ptrs_283 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %do_ptrs_365 = arith.andi %do_ptrs_364, %lse_361 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc935) + %lse_366 = ttg.async_copy_global_to_local %lse_362, %lse_363 mask %do_ptrs_365 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc740) + %lse_367 = ttg.async_commit_group tokens %lse_366 loc(#loc740) + %do_368 = tt.expand_dims %offs_m1_346 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc901) + %do_369 = arith.cmpi slt, %do_368, %do : tensor<64x1xi32, #blocked> loc(#loc899) + %do_370 = tt.broadcast %do_369 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc897) + %do_371 = ttg.memdesc_index %do_201[%do_ptrs_352] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_ptrs_372 = tt.splat %do_ptrs_283 : i1 -> tensor<64x128xi1, #blocked> loc(#loc935) + %do_ptrs_373 = arith.andi %do_ptrs_372, %do_370 : tensor<64x128xi1, #blocked> loc(#loc935) + %do_374 = ttg.async_copy_global_to_local %do_ptrs_340, %do_371 mask %do_ptrs_373 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc897) + %do_375 = ttg.async_commit_group tokens %do_374 loc(#loc897) + %Di_376 = tt.addptr %Di, %offs_m1_344 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc745) + %Di_377 = ttg.memdesc_index %Di_202[%do_ptrs_349] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_378 = ttg.async_copy_global_to_local %Di_376, %Di_377 mask %do_ptrs_365 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc742) + %Di_379 = ttg.async_commit_group tokens %Di_378 loc(#loc742) + scf.yield %dk_320, %dv_310, %qT_ptrs_337, %offs_m1_345, %do_ptrs_340, %offs_m1_346, %offs_m1_344, %do_ptrs_352, %do_ptrs_291, %do_ptrs_349, %do_ptrs_288, %offs_m1_272, %qT_275, %qT_360, %lse_277, %lse_367, %do_279, %do_375, %Di_281, %Di_379 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc935) + } loc(#loc935) + %do_ptrs_263:2 = ttng.warp_group_dot_wait %do_ptrs_262#1, %do_ptrs_262#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc935) + %do_ptrs_264 = ttg.async_wait {num = 0 : i32} loc(#loc935) + ttg.local_dealloc %Di_202 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc935) + ttg.local_dealloc %do_201 : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc935) + ttg.local_dealloc %lse_200 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc935) + ttg.local_dealloc %qT_199 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc935) + scf.yield %do_ptrs_263#0, %do_ptrs_263#1 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc291) + } loc(#loc687) + %dv_ptrs = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc593) + %dv_ptrs_103 = tt.addptr %dv_ptrs, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc593) + %dv_ptrs_104 = tt.broadcast %dv_ptrs_103 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc594) + %dv_ptrs_105 = tt.addptr %dv_ptrs_104, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc594) + %11 = arith.cmpi slt, %ptr_45, %cst_0 : tensor<1x128xi32, #blocked> loc(#loc294) + %12 = tt.broadcast %11 : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc295) + %13 = arith.andi %k_51, %12 : tensor<128x128xi1, #blocked> loc(#loc295) + %14 = arith.truncf %dk#0 : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc296) + %15 = ttg.convert_layout %14 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc296) + tt.store %dv_ptrs_105, %15, %13 : tensor<128x128x!tt.ptr, #blocked> loc(#loc296) + %dk_106 = arith.mulf %dk#1, %cst_5 : tensor<128x128xf32, #mma> loc(#loc595) + %16 = tt.splat %k_adj : i32 -> tensor<1x128xi32, #blocked> loc(#loc298) + %17 = arith.addi %ptr_45, %16 : tensor<1x128xi32, #blocked> loc(#loc298) + %18 = tt.broadcast %17 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc299) + %19 = tt.broadcast %ptr_41 : tensor<128x1xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc299) + %20 = arith.addi %18, %19 : tensor<128x128xi32, #blocked> loc(#loc299) + %21 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> loc(#loc300) + %22 = tt.addptr %21, %20 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc300) + %23 = arith.truncf %dk_106 : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc301) + %24 = ttg.convert_layout %23 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc301) + tt.store %22, %24, %k_51 : tensor<128x128x!tt.ptr, #blocked> loc(#loc301) + } loc(#loc28) + tt.return loc(#loc302) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":94:54) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:74) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:66) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:100) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:91) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:82) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:59) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:111) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":100:58) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":111:24) +#loc12 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":112:36) +#loc14 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":113:34) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":115:27) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":116:28) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:25) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:59) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:50) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:37) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:61) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":131:9) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":132:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":133:10) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":136:26) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:14) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:7) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":140:24) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:29) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:54) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:44) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":145:35) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":155:83) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:30) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:52) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:40) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:63) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:32) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:55) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:42) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:66) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:30) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:35) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:46) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:56) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":163:17) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":164:19) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":167:19) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":168:21) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":169:25) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":174:36) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":175:29) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:27) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":178:107) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:38) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:20) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:56) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:49) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:52) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:23) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":179:111) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:58) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:34) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:25) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:33) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:26) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:30) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:50) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":191:18) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":195:30) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:27) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:41) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:53) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:39) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:42) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:29) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:26) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":207:12) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:37) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:18) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:56) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:49) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:18) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:49) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:43) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:90) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:101) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:63) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:52) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":458:105) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":405:12) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":762:21) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":467:46) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":481:22) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":483:23) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":484:22) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":485:23) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":487:22) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:70) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:79) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:91) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:99) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:102) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:119) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":495:25) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:39) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:22) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:19) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:23) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":510:104) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":397:28) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:19) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":415:19) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":417:19) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:41) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":459:19) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:30) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":461:14) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":464:46) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":476:79) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":486:22) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":488:24) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":489:23) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:70) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:79) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:91) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:99) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:102) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:119) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":496:24) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":497:23) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":498:23) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":503:69) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":506:27) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:21) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":512:20) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:14) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":520:71) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":531:43) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":533:15) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:21) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":752:33) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":411:64) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:38) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:24) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:109) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:113) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:55) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:25) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:30) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:35) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:60) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:34) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:48) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:63) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:29) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:47) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:61) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:42) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:28) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":214:39) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:31) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:45) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:62) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:43) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":218:33) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":226:16) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:24) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:56) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":232:14) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:87) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:69) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:30) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":252:25) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":253:29) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":256:107) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":257:107) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:32) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:56) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:59) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:34) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":282:81) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":286:32) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:30) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:43) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:55) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:42) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:45) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:32) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:26) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":298:16) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:37) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:56) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:49) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:27) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:38) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:51) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:42) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:87) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:98) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:61) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":651:105) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":600:12) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:52) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":665:46) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":681:25) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":684:24) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":685:24) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":686:25) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":687:24) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:70) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:79) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:91) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:99) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:102) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:119) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":693:25) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":705:99) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":306:41) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:34) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:47) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:64) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:46) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":310:36) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":318:20) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":658:20) +#loc228 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":262:30) +#loc229 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:51) +#loc230 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:34) +#loc231 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:44) +#loc232 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:67) +#loc233 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:36) +#loc234 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:46) +#loc235 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:70) +#loc236 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:39) +#loc237 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:50) +#loc238 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:60) +#loc239 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":271:21) +#loc240 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":272:23) +#loc241 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":275:25) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":276:29) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:18) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:19) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:28) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:29) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:22) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:21) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":592:28) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:19) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:19) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":610:19) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:41) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:52) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:26) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:46) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":660:15) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":662:46) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":674:78) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":679:24) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":682:24) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":683:25) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:70) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:79) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:91) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:99) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:102) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:119) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":694:24) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":695:24) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":696:24) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":700:69) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":703:27) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:44) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:40) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:22) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:29) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:24) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:43) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:20) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:25) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:22) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:16) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":723:70) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":737:45) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:24) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:43) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":605:62) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:28) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:28) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":303:12) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:23) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:55) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:71) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:61) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:30) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":334:14) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:55) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:69) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:29) +#loc301 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:99) +#loc302 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:4) +#loc328 = loc("pid"(#loc11)) +#loc329 = loc("NUM_KV_BLOCKS"(#loc13)) +#loc330 = loc("NUM_Q_BLOCKS"(#loc15)) +#loc331 = loc("off_zq"(#loc16)) +#loc332 = loc("off_hkv"(#loc17)) +#loc333 = loc("k_adj"(#loc18)) +#loc334 = loc("k_adj"(#loc19)) +#loc335 = loc("dv_adj"(#loc20)) +#loc336 = loc("dv_adj"(#loc21)) +#loc337 = loc("dv_adj"(#loc22)) +#loc338 = loc("K"(#loc23)) +#loc339 = loc("V"(#loc24)) +#loc340 = loc("DV"(#loc25)) +#loc341 = loc("offs_k"(#loc26)) +#loc342 = loc("off_pid"(#loc29)) +#loc343 = loc("off_hq2"(#loc30)) +#loc344 = loc("off_hq2"(#loc31)) +#loc345 = loc("off_hq2"(#loc32)) +#loc346 = loc("start_m2_block"(#loc33)) +#loc347 = loc("sparse_kv_idx_offset"(#loc34)) +#loc348 = loc("q_adj2"(#loc35)) +#loc349 = loc("q_adj2"(#loc36)) +#loc350 = loc("q_adj2"(#loc37)) +#loc351 = loc("q_adj2"(#loc38)) +#loc352 = loc("do_adj2"(#loc39)) +#loc353 = loc("do_adj2"(#loc40)) +#loc354 = loc("do_adj2"(#loc41)) +#loc355 = loc("do_adj2"(#loc42)) +#loc356 = loc("off_chz2"(#loc43)) +#loc357 = loc("off_chz2"(#loc44)) +#loc358 = loc("off_chz2"(#loc45)) +#loc359 = loc("off_chz2"(#loc46)) +#loc360 = loc("Q2"(#loc47)) +#loc361 = loc("DO2"(#loc48)) +#loc362 = loc("DQ2"(#loc49)) +#loc363 = loc("LSE2"(#loc50)) +#loc364 = loc("DELTA2"(#loc51)) +#loc365 = loc("start_m2"(#loc52)) +#loc366 = loc("offs_m2"(#loc53)) +#loc367 = loc("ptr"(#loc54)) +#loc368 = loc("q"(#loc55)) +#loc369 = loc("ptr"(#loc56)) +#loc370 = loc("ptr"(#loc57)) +#loc371 = loc("ptr"(#loc58)) +#loc372 = loc("ptr"(#loc59)) +#loc373 = loc("do"(#loc62)) +#loc374 = loc("Di"(#loc63)) +#loc375 = loc("Di"(#loc64)) +#loc376 = loc("Di"(#loc65)) +#loc377 = loc("lse"(#loc66)) +#loc378 = loc("lse"(#loc67)) +#loc379 = loc("lse"(#loc68)) +#loc380 = loc("lse"(#loc69)) +#loc381 = loc("lse"(#loc70)) +#loc382 = loc("kv_indices"(#loc71)) +#loc383 = loc("kv_start"(#loc72)) +#loc384 = loc("kv_start"(#loc73)) +#loc385 = loc("sparse_kv_num_blocks"(#loc74)) +#loc386 = loc("sparse_kv_num_blocks"(#loc75)) +#loc387 = loc("offs_n2"(#loc76)) +#loc388 = loc("offs_n2"(#loc77)) +#loc389 = loc("kT_ptrs"(#loc78)) +#loc390 = loc("dq"(#loc79)) +#loc391 = loc("kT_ptrs"(#loc80)) +#loc392 = loc("kT_ptrs"(#loc81)) +#loc393 = loc("kT_ptrs"(#loc82)) +#loc394 = loc("kT_ptrs"(#loc83)) +#loc395 = loc("vT_ptrs"(#loc84)) +#loc396 = loc("vT_ptrs"(#loc85)) +#loc397 = loc("hi"(#loc86)) +#loc398 = loc("hi"(#loc87)) +#loc399 = loc("hi"(#loc88)) +#loc400 = loc("hi"(#loc89)) +#loc401 = loc("kT"(#loc91)) +#loc402 = loc("dq"(#loc92)) +#loc403 = loc("m"(#loc94)) +#loc404 = loc("tmp3"(#loc95)) +#loc405 = loc("tmp5"(#loc96)) +#loc406 = loc("tmp6"(#loc97)) +#loc407 = loc("tmp7"(#loc98)) +#loc408 = loc("tmp9"(#loc99)) +#loc409 = loc("tmp14"(#loc100)) +#loc410 = loc("tmp14"(#loc101)) +#loc411 = loc("tmp14"(#loc102)) +#loc412 = loc("tmp14"(#loc103)) +#loc413 = loc("tmp14"(#loc104)) +#loc414 = loc("tmp14"(#loc105)) +#loc415 = loc("tmp17"(#loc106)) +#loc416 = loc("p"(#loc107)) +#loc417 = loc("ds"(#loc108)) +#loc418 = loc("ds"(#loc109)) +#loc419 = loc("vT"(#loc111)) +#loc420 = loc("dq"(#loc112)) +#loc421 = loc("kT_ptrs"(#loc113)) +#loc422 = loc("vT_ptrs"(#loc114)) +#loc423 = loc("offs_n2"(#loc115)) +#loc424 = loc("qk"(#loc117)) +#loc425 = loc("dq"(#loc118)) +#loc426 = loc("qk"(#loc119)) +#loc427 = loc("n"(#loc120)) +#loc428 = loc("post_mod_scores"(#loc121)) +#loc429 = loc("tmp8"(#loc122)) +#loc430 = loc("tmp10"(#loc123)) +#loc431 = loc("tmp11"(#loc124)) +#loc432 = loc("tmp16"(#loc125)) +#loc433 = loc("tmp16"(#loc126)) +#loc434 = loc("tmp16"(#loc127)) +#loc435 = loc("tmp16"(#loc128)) +#loc436 = loc("tmp16"(#loc129)) +#loc437 = loc("tmp16"(#loc130)) +#loc438 = loc("tmp18"(#loc131)) +#loc439 = loc("tmp19"(#loc132)) +#loc440 = loc("tmp20"(#loc133)) +#loc441 = loc("post_mod_scores"(#loc134)) +#loc442 = loc("post_mod_scores"(#loc135)) +#loc443 = loc("p"(#loc136)) +#loc444 = loc("dp"(#loc137)) +#loc445 = loc("ds"(#loc138)) +#loc446 = loc("grad_scores"(#loc139)) +#loc447 = loc("ds"(#loc140)) +#loc448 = loc("ds"(#loc141)) +#loc449 = loc("dq"(#loc142)) +#loc450 = loc("cur_block_idx"(#loc143)) +#loc451 = loc("offset"(#loc144)) +#loc452 = loc("cur_block"(#loc145)) +#loc453 = loc("cur_block"(#loc146)) +#loc454 = loc("next_block"(#loc147)) +#loc455 = loc("next_block"(#loc148)) +#loc456 = loc("next_block"(#loc149)) +#loc457 = loc("next_block"(#loc150)) +#loc458 = loc("needs_jump"(#loc151)) +#loc459 = loc("needs_jump"(#loc152)) +#loc460 = loc("needs_jump"(#loc153)) +#loc461 = loc("jump_to_block"(#loc154)) +#loc462 = loc("jump_to_block"(#loc155)) +#loc463 = loc("jump_to_block"(#loc156)) +#loc464 = loc("offset"(#loc157)) +#loc465 = loc("offset"(#loc158)) +#loc466 = loc("offset"(#loc159)) +#loc467 = loc("offset"(#loc160)) +#loc468 = loc("kT_ptrs"(#loc161)) +#loc469 = loc("kv_indices"(#loc162)) +#loc470 = loc("kv_start"(#loc163)) +#loc471 = loc("kv_start"(#loc164)) +#loc472 = loc("sparse_kv_num_blocks"(#loc165)) +#loc473 = loc("sparse_kv_num_blocks"(#loc166)) +#loc474 = loc("offs_n2"(#loc167)) +#loc475 = loc("dq"(#loc168)) +#loc476 = loc("dq_ptrs"(#loc169)) +#loc477 = loc("dq_ptrs"(#loc170)) +#loc478 = loc("dq"(#loc171)) +#loc479 = loc("start_n1"(#loc175)) +#loc480 = loc("offs_n1"(#loc176)) +#loc481 = loc("k"(#loc177)) +#loc482 = loc("v"(#loc178)) +#loc483 = loc("off_hq1"(#loc179)) +#loc484 = loc("q_adj1"(#loc180)) +#loc485 = loc("do_adj1"(#loc181)) +#loc486 = loc("off_chz1"(#loc182)) +#loc487 = loc("sparse_q_idx_offset"(#loc183)) +#loc488 = loc("q_indices"(#loc184)) +#loc489 = loc("q_start"(#loc185)) +#loc490 = loc("q_start"(#loc186)) +#loc491 = loc("sparse_q_num_blocks"(#loc187)) +#loc492 = loc("sparse_q_num_blocks"(#loc188)) +#loc493 = loc("offs_m1"(#loc189)) +#loc494 = loc("offs_m1"(#loc190)) +#loc495 = loc("qT_ptrs"(#loc191)) +#loc496 = loc("qT_ptrs"(#loc193)) +#loc497 = loc("qT_ptrs"(#loc194)) +#loc498 = loc("qT_ptrs"(#loc195)) +#loc499 = loc("do_ptrs"(#loc196)) +#loc500 = loc("do_ptrs"(#loc197)) +#loc501 = loc("do_ptrs"(#loc198)) +#loc502 = loc("hi"(#loc199)) +#loc503 = loc("hi"(#loc200)) +#loc504 = loc("hi"(#loc201)) +#loc505 = loc("hi"(#loc202)) +#loc506 = loc("qT"(#loc203)) +#loc507 = loc(callsite(#loc204 at #loc192)) +#loc508 = loc("lse"(#loc205)) +#loc509 = loc("n"(#loc206)) +#loc510 = loc("tmp27"(#loc207)) +#loc511 = loc("tmp30"(#loc208)) +#loc512 = loc("tmp31"(#loc209)) +#loc513 = loc("tmp32"(#loc210)) +#loc514 = loc("tmp33"(#loc211)) +#loc515 = loc("tmp38"(#loc212)) +#loc516 = loc("tmp38"(#loc213)) +#loc517 = loc("tmp38"(#loc214)) +#loc518 = loc("tmp38"(#loc215)) +#loc519 = loc("tmp38"(#loc216)) +#loc520 = loc("tmp38"(#loc217)) +#loc521 = loc("tmp39"(#loc218)) +#loc522 = loc("do"(#loc219)) +#loc523 = loc("q_indices"(#loc220)) +#loc524 = loc("q_start"(#loc221)) +#loc525 = loc("q_start"(#loc222)) +#loc526 = loc("sparse_q_num_blocks"(#loc223)) +#loc527 = loc("sparse_q_num_blocks"(#loc224)) +#loc528 = loc("offs_m1"(#loc225)) +#loc529 = loc("qkT"(#loc227)) +#loc530 = loc("dv"(#loc228)) +#loc531 = loc("off_hq1"(#loc229)) +#loc532 = loc("q_adj1"(#loc230)) +#loc533 = loc("q_adj1"(#loc231)) +#loc534 = loc("q_adj1"(#loc232)) +#loc535 = loc("do_adj1"(#loc233)) +#loc536 = loc("do_adj1"(#loc234)) +#loc537 = loc("do_adj1"(#loc235)) +#loc538 = loc("off_chz1"(#loc236)) +#loc539 = loc("off_chz1"(#loc237)) +#loc540 = loc("off_chz1"(#loc238)) +#loc541 = loc("Q1"(#loc239)) +#loc542 = loc("DO1"(#loc240)) +#loc543 = loc("LSE1"(#loc241)) +#loc544 = loc("DELTA1"(#loc242)) +#loc545 = loc("qT_ptrs"(#loc243)) +#loc546 = loc("do_ptrs"(#loc244)) +#loc547 = loc("lse"(#loc245)) +#loc548 = loc("Di"(#loc246)) +#loc549 = loc("lse"(#loc247)) +#loc550 = loc("Di"(#loc248)) +#loc551 = loc("dk"(#loc249)) +#loc552 = loc("qT_ptrs"(#loc250)) +#loc553 = loc("do_ptrs"(#loc251)) +#loc554 = loc("offs_m1"(#loc252)) +#loc555 = loc("dk"(#loc254)) +#loc556 = loc("lse"(#loc255)) +#loc557 = loc("lse"(#loc256)) +#loc558 = loc("qkT"(#loc257)) +#loc559 = loc("m"(#loc258)) +#loc560 = loc("post_mod_scores"(#loc259)) +#loc561 = loc("tmp25"(#loc260)) +#loc562 = loc("tmp28"(#loc261)) +#loc563 = loc("tmp29"(#loc262)) +#loc564 = loc("tmp36"(#loc263)) +#loc565 = loc("tmp36"(#loc264)) +#loc566 = loc("tmp36"(#loc265)) +#loc567 = loc("tmp36"(#loc266)) +#loc568 = loc("tmp36"(#loc267)) +#loc569 = loc("tmp36"(#loc268)) +#loc570 = loc("tmp40"(#loc269)) +#loc571 = loc("tmp41"(#loc270)) +#loc572 = loc("tmp42"(#loc271)) +#loc573 = loc("post_mod_scores"(#loc272)) +#loc574 = loc("post_mod_scores"(#loc273)) +#loc575 = loc("pT"(#loc274)) +#loc576 = loc("pT"(#loc275)) +#loc577 = loc("pT"(#loc276)) +#loc578 = loc("dpT"(#loc277)) +#loc579 = loc("dv"(#loc278)) +#loc580 = loc("dv"(#loc279)) +#loc581 = loc("dpT"(#loc280)) +#loc582 = loc("dsT"(#loc281)) +#loc583 = loc("dsT"(#loc282)) +#loc584 = loc("dsT"(#loc283)) +#loc585 = loc("grad_scores"(#loc284)) +#loc586 = loc("dsT"(#loc285)) +#loc587 = loc("dk"(#loc286)) +#loc588 = loc("dk"(#loc287)) +#loc589 = loc("offset"(#loc288)) +#loc590 = loc("qT_ptrs"(#loc289)) +#loc591 = loc("do_ptrs"(#loc290)) +#loc592 = loc(callsite(#loc204 at #loc226)) +#loc593 = loc("dv_ptrs"(#loc292)) +#loc594 = loc("dv_ptrs"(#loc293)) +#loc595 = loc("dk"(#loc297)) +#loc596 = loc(callsite(#loc12 at #loc329)) +#loc597 = loc(callsite(#loc14 at #loc329)) +#loc598 = loc(callsite(#loc12 at #loc330)) +#loc599 = loc(callsite(#loc14 at #loc330)) +#loc600 = loc(callsite(#loc367 at #loc368)) +#loc601 = loc(callsite(#loc369 at #loc368)) +#loc602 = loc(callsite(#loc370 at #loc368)) +#loc603 = loc(callsite(#loc371 at #loc368)) +#loc604 = loc(callsite(#loc372 at #loc368)) +#loc605 = loc(callsite(#loc60 at #loc368)) +#loc606 = loc(callsite(#loc61 at #loc368)) +#loc607 = loc(callsite(#loc369 at #loc373)) +#loc608 = loc(callsite(#loc370 at #loc373)) +#loc609 = loc(callsite(#loc372 at #loc373)) +#loc610 = loc(callsite(#loc61 at #loc373)) +#loc611 = loc(callsite(#loc389 at #loc390)) +#loc612 = loc(callsite(#loc391 at #loc390)) +#loc613 = loc(callsite(#loc392 at #loc390)) +#loc614 = loc(callsite(#loc393 at #loc390)) +#loc615 = loc(callsite(#loc394 at #loc390)) +#loc616 = loc(callsite(#loc395 at #loc390)) +#loc617 = loc(callsite(#loc396 at #loc390)) +#loc618 = loc(callsite(#loc397 at #loc390)) +#loc619 = loc(callsite(#loc398 at #loc390)) +#loc620 = loc(callsite(#loc399 at #loc390)) +#loc621 = loc(callsite(#loc400 at #loc390)) +#loc622 = loc(callsite(#loc402 at #loc390)) +#loc623 = loc("offs_n2"(#loc420)) +#loc624 = loc(callsite(#loc421 at #loc390)) +#loc625 = loc(callsite(#loc422 at #loc390)) +#loc626 = loc(callsite(#loc423 at #loc390)) +#loc627 = loc(callsite(#loc451 at #loc390)) +#loc628 = loc(callsite(#loc468 at #loc390)) +#loc629 = loc(callsite(#loc389 at #loc475)) +#loc630 = loc(callsite(#loc391 at #loc475)) +#loc631 = loc(callsite(#loc392 at #loc475)) +#loc632 = loc(callsite(#loc394 at #loc475)) +#loc633 = loc(callsite(#loc395 at #loc475)) +#loc634 = loc(callsite(#loc396 at #loc475)) +#loc635 = loc(callsite(#loc397 at #loc475)) +#loc636 = loc(callsite(#loc400 at #loc475)) +#loc637 = loc(callsite(#loc402 at #loc475)) +#loc638 = loc(callsite(#loc421 at #loc475)) +#loc639 = loc(callsite(#loc422 at #loc475)) +#loc640 = loc(callsite(#loc423 at #loc475)) +#loc641 = loc(callsite(#loc451 at #loc475)) +#loc642 = loc(callsite(#loc468 at #loc475)) +#loc643 = loc(callsite(#loc367 at #loc481)) +#loc644 = loc(callsite(#loc369 at #loc481)) +#loc645 = loc(callsite(#loc370 at #loc481)) +#loc646 = loc(callsite(#loc371 at #loc481)) +#loc647 = loc(callsite(#loc372 at #loc481)) +#loc648 = loc(callsite(#loc60 at #loc481)) +#loc649 = loc(callsite(#loc61 at #loc481)) +#loc650 = loc(callsite(#loc370 at #loc482)) +#loc651 = loc(callsite(#loc372 at #loc482)) +#loc652 = loc(callsite(#loc61 at #loc482)) +#loc653 = loc(callsite(#loc495 at #loc192)) +#loc654 = loc(callsite(#loc496 at #loc192)) +#loc655 = loc(callsite(#loc497 at #loc192)) +#loc656 = loc(callsite(#loc498 at #loc192)) +#loc657 = loc(callsite(#loc499 at #loc192)) +#loc658 = loc(callsite(#loc500 at #loc192)) +#loc659 = loc(callsite(#loc501 at #loc192)) +#loc660 = loc(callsite(#loc502 at #loc192)) +#loc661 = loc(callsite(#loc503 at #loc192)) +#loc662 = loc(callsite(#loc504 at #loc192)) +#loc663 = loc(callsite(#loc505 at #loc192)) +#loc664 = loc(callsite(#loc506 at #loc507)) +#loc665 = loc(callsite(#loc508 at #loc507)) +#loc666 = loc(callsite(#loc509 at #loc507)) +#loc667 = loc(callsite(#loc510 at #loc507)) +#loc668 = loc(callsite(#loc511 at #loc507)) +#loc669 = loc(callsite(#loc512 at #loc507)) +#loc670 = loc(callsite(#loc513 at #loc507)) +#loc671 = loc(callsite(#loc514 at #loc507)) +#loc672 = loc(callsite(#loc515 at #loc507)) +#loc673 = loc(callsite(#loc516 at #loc507)) +#loc674 = loc(callsite(#loc517 at #loc507)) +#loc675 = loc(callsite(#loc518 at #loc507)) +#loc676 = loc(callsite(#loc519 at #loc507)) +#loc677 = loc(callsite(#loc520 at #loc507)) +#loc678 = loc(callsite(#loc521 at #loc507)) +#loc679 = loc(callsite(#loc522 at #loc507)) +#loc680 = loc(callsite(#loc495 at #loc226)) +#loc681 = loc(callsite(#loc496 at #loc226)) +#loc682 = loc(callsite(#loc499 at #loc226)) +#loc683 = loc(callsite(#loc500 at #loc226)) +#loc684 = loc(callsite(#loc502 at #loc226)) +#loc685 = loc(callsite(#loc505 at #loc226)) +#loc686 = loc(callsite(#loc529 at #loc507)) +#loc687 = loc("dk"(#loc530)) +#loc688 = loc(callsite(#loc545 at #loc192)) +#loc689 = loc(callsite(#loc546 at #loc192)) +#loc690 = loc(callsite(#loc547 at #loc507)) +#loc691 = loc(callsite(#loc548 at #loc507)) +#loc692 = loc(callsite(#loc549 at #loc507)) +#loc693 = loc(callsite(#loc550 at #loc507)) +#loc694 = loc("dv"(#loc551)) +#loc695 = loc(callsite(#loc552 at #loc192)) +#loc696 = loc(callsite(#loc553 at #loc192)) +#loc697 = loc(callsite(#loc554 at #loc192)) +#loc698 = loc(callsite(#loc555 at #loc507)) +#loc699 = loc(callsite(#loc556 at #loc507)) +#loc700 = loc(callsite(#loc557 at #loc507)) +#loc701 = loc(callsite(#loc558 at #loc507)) +#loc702 = loc(callsite(#loc559 at #loc507)) +#loc703 = loc(callsite(#loc560 at #loc507)) +#loc704 = loc(callsite(#loc561 at #loc507)) +#loc705 = loc(callsite(#loc562 at #loc507)) +#loc706 = loc(callsite(#loc563 at #loc507)) +#loc707 = loc(callsite(#loc564 at #loc507)) +#loc708 = loc(callsite(#loc565 at #loc507)) +#loc709 = loc(callsite(#loc566 at #loc507)) +#loc710 = loc(callsite(#loc567 at #loc507)) +#loc711 = loc(callsite(#loc568 at #loc507)) +#loc712 = loc(callsite(#loc569 at #loc507)) +#loc713 = loc(callsite(#loc570 at #loc507)) +#loc714 = loc(callsite(#loc571 at #loc507)) +#loc715 = loc(callsite(#loc572 at #loc507)) +#loc716 = loc(callsite(#loc573 at #loc507)) +#loc717 = loc(callsite(#loc574 at #loc507)) +#loc718 = loc(callsite(#loc575 at #loc507)) +#loc719 = loc(callsite(#loc576 at #loc507)) +#loc720 = loc(callsite(#loc577 at #loc507)) +#loc721 = loc(callsite(#loc578 at #loc507)) +#loc722 = loc(callsite(#loc579 at #loc507)) +#loc723 = loc(callsite(#loc580 at #loc507)) +#loc724 = loc(callsite(#loc581 at #loc507)) +#loc725 = loc(callsite(#loc582 at #loc507)) +#loc726 = loc(callsite(#loc583 at #loc507)) +#loc727 = loc(callsite(#loc584 at #loc507)) +#loc728 = loc(callsite(#loc585 at #loc507)) +#loc729 = loc(callsite(#loc586 at #loc507)) +#loc730 = loc(callsite(#loc587 at #loc507)) +#loc731 = loc(callsite(#loc588 at #loc507)) +#loc732 = loc(callsite(#loc589 at #loc192)) +#loc733 = loc(callsite(#loc590 at #loc192)) +#loc734 = loc(callsite(#loc591 at #loc192)) +#loc735 = loc(callsite(#loc545 at #loc226)) +#loc736 = loc(callsite(#loc498 at #loc226)) +#loc737 = loc(callsite(#loc546 at #loc226)) +#loc738 = loc(callsite(#loc501 at #loc226)) +#loc739 = loc(callsite(#loc506 at #loc592)) +#loc740 = loc(callsite(#loc549 at #loc592)) +#loc741 = loc(callsite(#loc522 at #loc592)) +#loc742 = loc(callsite(#loc550 at #loc592)) +#loc743 = loc(callsite(#loc508 at #loc592)) +#loc744 = loc(callsite(#loc547 at #loc592)) +#loc745 = loc(callsite(#loc548 at #loc592)) +#loc746 = loc(callsite(#loc552 at #loc226)) +#loc747 = loc(callsite(#loc553 at #loc226)) +#loc748 = loc(callsite(#loc554 at #loc226)) +#loc749 = loc(callsite(#loc555 at #loc592)) +#loc750 = loc(callsite(#loc556 at #loc592)) +#loc751 = loc(callsite(#loc557 at #loc592)) +#loc752 = loc(callsite(#loc529 at #loc592)) +#loc753 = loc(callsite(#loc558 at #loc592)) +#loc754 = loc(callsite(#loc560 at #loc592)) +#loc755 = loc(callsite(#loc574 at #loc592)) +#loc756 = loc(callsite(#loc575 at #loc592)) +#loc757 = loc(callsite(#loc576 at #loc592)) +#loc758 = loc(callsite(#loc577 at #loc592)) +#loc759 = loc(callsite(#loc578 at #loc592)) +#loc760 = loc(callsite(#loc579 at #loc592)) +#loc761 = loc(callsite(#loc580 at #loc592)) +#loc762 = loc(callsite(#loc581 at #loc592)) +#loc763 = loc(callsite(#loc582 at #loc592)) +#loc764 = loc(callsite(#loc583 at #loc592)) +#loc765 = loc(callsite(#loc584 at #loc592)) +#loc766 = loc(callsite(#loc585 at #loc592)) +#loc767 = loc(callsite(#loc587 at #loc592)) +#loc768 = loc(callsite(#loc588 at #loc592)) +#loc769 = loc(callsite(#loc589 at #loc226)) +#loc770 = loc(callsite(#loc590 at #loc226)) +#loc771 = loc(callsite(#loc591 at #loc226)) +#loc772 = loc(callsite(#loc12 at #loc619)) +#loc773 = loc(callsite(#loc14 at #loc619)) +#loc774 = loc(callsite(#loc401 at #loc622)) +#loc775 = loc(callsite(#loc403 at #loc622)) +#loc776 = loc(callsite(#loc404 at #loc622)) +#loc777 = loc(callsite(#loc405 at #loc622)) +#loc778 = loc(callsite(#loc406 at #loc622)) +#loc779 = loc(callsite(#loc407 at #loc622)) +#loc780 = loc(callsite(#loc408 at #loc622)) +#loc781 = loc(callsite(#loc409 at #loc622)) +#loc782 = loc(callsite(#loc410 at #loc622)) +#loc783 = loc(callsite(#loc411 at #loc622)) +#loc784 = loc(callsite(#loc412 at #loc622)) +#loc785 = loc(callsite(#loc413 at #loc622)) +#loc786 = loc(callsite(#loc414 at #loc622)) +#loc787 = loc(callsite(#loc415 at #loc622)) +#loc788 = loc(callsite(#loc416 at #loc622)) +#loc789 = loc(callsite(#loc417 at #loc622)) +#loc790 = loc(callsite(#loc418 at #loc622)) +#loc791 = loc(callsite(#loc419 at #loc622)) +#loc792 = loc("kT_ptrs"(#loc623)) +#loc793 = loc(callsite(#loc424 at #loc622)) +#loc794 = loc(callsite(#loc425 at #loc622)) +#loc795 = loc(callsite(#loc426 at #loc622)) +#loc796 = loc(callsite(#loc427 at #loc622)) +#loc797 = loc(callsite(#loc428 at #loc622)) +#loc798 = loc(callsite(#loc429 at #loc622)) +#loc799 = loc(callsite(#loc430 at #loc622)) +#loc800 = loc(callsite(#loc431 at #loc622)) +#loc801 = loc(callsite(#loc432 at #loc622)) +#loc802 = loc(callsite(#loc433 at #loc622)) +#loc803 = loc(callsite(#loc434 at #loc622)) +#loc804 = loc(callsite(#loc435 at #loc622)) +#loc805 = loc(callsite(#loc436 at #loc622)) +#loc806 = loc(callsite(#loc437 at #loc622)) +#loc807 = loc(callsite(#loc438 at #loc622)) +#loc808 = loc(callsite(#loc439 at #loc622)) +#loc809 = loc(callsite(#loc440 at #loc622)) +#loc810 = loc(callsite(#loc441 at #loc622)) +#loc811 = loc(callsite(#loc442 at #loc622)) +#loc812 = loc(callsite(#loc443 at #loc622)) +#loc813 = loc(callsite(#loc444 at #loc622)) +#loc814 = loc(callsite(#loc445 at #loc622)) +#loc815 = loc(callsite(#loc446 at #loc622)) +#loc816 = loc(callsite(#loc447 at #loc622)) +#loc817 = loc(callsite(#loc448 at #loc622)) +#loc818 = loc(callsite(#loc449 at #loc622)) +#loc819 = loc(callsite(#loc450 at #loc627)) +#loc820 = loc(callsite(#loc452 at #loc627)) +#loc821 = loc(callsite(#loc453 at #loc627)) +#loc822 = loc(callsite(#loc454 at #loc627)) +#loc823 = loc(callsite(#loc455 at #loc627)) +#loc824 = loc(callsite(#loc456 at #loc627)) +#loc825 = loc(callsite(#loc457 at #loc627)) +#loc826 = loc(callsite(#loc458 at #loc627)) +#loc827 = loc(callsite(#loc459 at #loc627)) +#loc828 = loc(callsite(#loc460 at #loc627)) +#loc829 = loc(callsite(#loc461 at #loc627)) +#loc830 = loc(callsite(#loc462 at #loc627)) +#loc831 = loc(callsite(#loc463 at #loc627)) +#loc832 = loc(callsite(#loc464 at #loc627)) +#loc833 = loc(callsite(#loc465 at #loc627)) +#loc834 = loc(callsite(#loc466 at #loc627)) +#loc835 = loc(callsite(#loc467 at #loc627)) +#loc836 = loc(callsite(#loc401 at #loc637)) +#loc837 = loc(callsite(#loc419 at #loc637)) +#loc838 = loc(callsite(#loc424 at #loc637)) +#loc839 = loc(callsite(#loc425 at #loc637)) +#loc840 = loc(callsite(#loc426 at #loc637)) +#loc841 = loc(callsite(#loc428 at #loc637)) +#loc842 = loc(callsite(#loc442 at #loc637)) +#loc843 = loc(callsite(#loc416 at #loc637)) +#loc844 = loc(callsite(#loc443 at #loc637)) +#loc845 = loc(callsite(#loc444 at #loc637)) +#loc846 = loc(callsite(#loc418 at #loc637)) +#loc847 = loc(callsite(#loc445 at #loc637)) +#loc848 = loc(callsite(#loc446 at #loc637)) +#loc849 = loc(callsite(#loc448 at #loc637)) +#loc850 = loc(callsite(#loc449 at #loc637)) +#loc851 = loc(callsite(#loc450 at #loc641)) +#loc852 = loc(callsite(#loc452 at #loc641)) +#loc853 = loc(callsite(#loc453 at #loc641)) +#loc854 = loc(callsite(#loc454 at #loc641)) +#loc855 = loc(callsite(#loc455 at #loc641)) +#loc856 = loc(callsite(#loc456 at #loc641)) +#loc857 = loc(callsite(#loc457 at #loc641)) +#loc858 = loc(callsite(#loc458 at #loc641)) +#loc859 = loc(callsite(#loc459 at #loc641)) +#loc860 = loc(callsite(#loc460 at #loc641)) +#loc861 = loc(callsite(#loc461 at #loc641)) +#loc862 = loc(callsite(#loc462 at #loc641)) +#loc863 = loc(callsite(#loc463 at #loc641)) +#loc864 = loc(callsite(#loc464 at #loc641)) +#loc865 = loc(callsite(#loc465 at #loc641)) +#loc866 = loc(callsite(#loc466 at #loc641)) +#loc867 = loc(callsite(#loc467 at #loc641)) +#loc868 = loc(callsite(#loc12 at #loc661)) +#loc869 = loc(callsite(#loc14 at #loc661)) +#loc870 = loc(callsite(#loc90 at #loc664)) +#loc871 = loc(callsite(#loc93 at #loc666)) +#loc872 = loc(callsite(#loc60 at #loc679)) +#loc873 = loc(callsite(#loc110 at #loc664)) +#loc874 = loc(callsite(#loc61 at #loc679)) +#loc875 = loc("offs_m1"(#loc694)) +#loc876 = loc(callsite(#loc116 at #loc664)) +#loc877 = loc(callsite(#loc253 at #loc679)) +#loc878 = loc(callsite(#loc93 at #loc702)) +#loc879 = loc(callsite(#loc450 at #loc732)) +#loc880 = loc(callsite(#loc452 at #loc732)) +#loc881 = loc(callsite(#loc453 at #loc732)) +#loc882 = loc(callsite(#loc454 at #loc732)) +#loc883 = loc(callsite(#loc455 at #loc732)) +#loc884 = loc(callsite(#loc456 at #loc732)) +#loc885 = loc(callsite(#loc457 at #loc732)) +#loc886 = loc(callsite(#loc458 at #loc732)) +#loc887 = loc(callsite(#loc459 at #loc732)) +#loc888 = loc(callsite(#loc460 at #loc732)) +#loc889 = loc(callsite(#loc461 at #loc732)) +#loc890 = loc(callsite(#loc462 at #loc732)) +#loc891 = loc(callsite(#loc463 at #loc732)) +#loc892 = loc(callsite(#loc464 at #loc732)) +#loc893 = loc(callsite(#loc465 at #loc732)) +#loc894 = loc(callsite(#loc466 at #loc732)) +#loc895 = loc(callsite(#loc467 at #loc732)) +#loc896 = loc(callsite(#loc110 at #loc739)) +#loc897 = loc(callsite(#loc61 at #loc741)) +#loc898 = loc(callsite(#loc90 at #loc739)) +#loc899 = loc(callsite(#loc60 at #loc741)) +#loc900 = loc(callsite(#loc116 at #loc739)) +#loc901 = loc(callsite(#loc253 at #loc741)) +#loc902 = loc(callsite(#loc450 at #loc769)) +#loc903 = loc(callsite(#loc452 at #loc769)) +#loc904 = loc(callsite(#loc453 at #loc769)) +#loc905 = loc(callsite(#loc454 at #loc769)) +#loc906 = loc(callsite(#loc455 at #loc769)) +#loc907 = loc(callsite(#loc456 at #loc769)) +#loc908 = loc(callsite(#loc457 at #loc769)) +#loc909 = loc(callsite(#loc458 at #loc769)) +#loc910 = loc(callsite(#loc459 at #loc769)) +#loc911 = loc(callsite(#loc460 at #loc769)) +#loc912 = loc(callsite(#loc461 at #loc769)) +#loc913 = loc(callsite(#loc462 at #loc769)) +#loc914 = loc(callsite(#loc463 at #loc769)) +#loc915 = loc(callsite(#loc464 at #loc769)) +#loc916 = loc(callsite(#loc465 at #loc769)) +#loc917 = loc(callsite(#loc466 at #loc769)) +#loc918 = loc(callsite(#loc467 at #loc769)) +#loc919 = loc(callsite(#loc90 at #loc774)) +#loc920 = loc(callsite(#loc93 at #loc775)) +#loc921 = loc(callsite(#loc110 at #loc774)) +#loc922 = loc(callsite(#loc110 at #loc791)) +#loc923 = loc("vT_ptrs"(#loc792)) +#loc924 = loc(callsite(#loc116 at #loc774)) +#loc925 = loc(callsite(#loc93 at #loc796)) +#loc926 = loc(callsite(#loc110 at #loc836)) +#loc927 = loc(callsite(#loc110 at #loc837)) +#loc928 = loc(callsite(#loc90 at #loc836)) +#loc929 = loc(callsite(#loc116 at #loc836)) +#loc930 = loc("qT_ptrs"(#loc875)) +#loc931 = loc(callsite(#loc923 at #loc390)) +#loc932 = loc(callsite(#loc923 at #loc475)) +#loc933 = loc("do_ptrs"(#loc930)) +#loc934 = loc(callsite(#loc933 at #loc192)) +#loc935 = loc(callsite(#loc933 at #loc226)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttir new file mode 100644 index 0000000000000000000000000000000000000000..a19f2e90d97d9103f43bccc97e7241321e2984ce --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/7XYCMOPGDICSBBN4N46UWO2BLUA2ZH4YYLO2F5SVEXNOEL6NVJRQ/triton_tem_fused_mul_1.ttir @@ -0,0 +1,1542 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":18:0) +#loc306 = loc("arg_Q"(#loc)) +#loc307 = loc("arg_K"(#loc)) +#loc308 = loc("arg_V"(#loc)) +#loc309 = loc("arg_LSE"(#loc)) +#loc310 = loc("arg_DELTA"(#loc)) +#loc311 = loc("arg_DO"(#loc)) +#loc312 = loc("arg_DQ"(#loc)) +#loc313 = loc("arg_DV"(#loc)) +#loc314 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc315 = loc("arg_KV_IDX"(#loc)) +#loc316 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc317 = loc("arg_Q_IDX"(#loc)) +#loc318 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc319 = loc("arg_FULL_KV_IDX"(#loc)) +#loc320 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc321 = loc("arg_FULL_Q_IDX"(#loc)) +#loc322 = loc("out_ptr0"(#loc)) +#loc323 = loc("ks0"(#loc)) +#loc324 = loc("ks1"(#loc)) +#loc325 = loc("ks2"(#loc)) +#loc326 = loc("ks3"(#loc)) +#loc327 = loc("ks4"(#loc)) +#loc328 = loc("ks5"(#loc)) +#loc329 = loc("ks6"(#loc)) +#loc330 = loc("ks7"(#loc)) +module { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc)), %ks2: i32 loc("ks2"(#loc)), %ks3: i32 loc("ks3"(#loc)), %ks4: i32 loc("ks4"(#loc)), %ks5: i32 loc("ks5"(#loc)), %ks6: i32 loc("ks6"(#loc)), %ks7: i32 loc("ks7"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<64x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<4096> : tensor<1x64xi32> loc(#loc1) + %cst_1 = arith.constant dense<1024> : tensor<1x64xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x128xbf16> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1) + %cst_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_6 = arith.constant dense<16> : tensor<1x64xi32> loc(#loc1) + %cst_7 = arith.constant dense<16> : tensor<128x1xi32> loc(#loc1) + %cst_8 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1) + %cst_9 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc1) + %cst_10 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc1) + %cst_11 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1) + %cst_12 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1) + %cst_13 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1) + %cst_14 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1) + %cst_15 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1) + %cst_16 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> loc(#loc1) + %cst_17 = arith.constant dense<0.000000e+00> : tensor<128x128xbf16> loc(#loc1) + %cst_18 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_19 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc1) + %cst_20 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc1) + %cst_21 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc1) + %cst_22 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc1) + %cst_23 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc1) + %cst_24 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %HQ = arith.constant 32 : i32 loc(#loc331) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %0 = arith.muli %ks0, %c4096_i32 : i32 loc(#loc3) + %1 = arith.cmpi sle, %ks0, %c1_i32 : i32 loc(#loc4) + %2 = arith.extui %1 : i1 to i32 loc(#loc5) + %3 = arith.cmpi sgt, %ks0, %c1_i32 : i32 loc(#loc6) + %4 = arith.extui %3 : i1 to i32 loc(#loc7) + %5 = arith.muli %ks0, %4 : i32 loc(#loc7) + %6 = arith.addi %2, %5 : i32 loc(#loc8) + %7 = arith.muli %6, %c4096_i32 : i32 loc(#loc9) + %8 = arith.muli %6, %c128_i32 : i32 loc(#loc10) + %9 = arith.muli %ks1, %c1024_i32 : i32 loc(#loc11) + %pid = tt.get_program_id x : i32 loc(#loc332) + %NUM_KV_BLOCKS = arith.addi %ks1, %c127_i32 : i32 loc(#loc602) + %NUM_KV_BLOCKS_25 = arith.divsi %NUM_KV_BLOCKS, %c128_i32 : i32 loc(#loc603) + %NUM_Q_BLOCKS = arith.addi %ks0, %c127_i32 : i32 loc(#loc604) + %NUM_Q_BLOCKS_26 = arith.divsi %NUM_Q_BLOCKS, %c128_i32 : i32 loc(#loc605) + %off_zq = tt.get_program_id y : i32 loc(#loc335) + %off_hkv = tt.get_program_id z : i32 loc(#loc336) + %k_adj = arith.muli %off_hkv, %c128_i32 : i32 loc(#loc337) + %k_adj_27 = arith.extsi %k_adj : i32 to i64 loc(#loc338) + %dv_adj = arith.muli %9, %off_zq : i32 loc(#loc339) + %dv_adj_28 = arith.addi %k_adj, %dv_adj : i32 loc(#loc340) + %dv_adj_29 = arith.extsi %dv_adj_28 : i32 to i64 loc(#loc341) + %K = tt.addptr %arg_K, %k_adj_27 : !tt.ptr, i64 loc(#loc342) + %V = tt.addptr %arg_V, %k_adj_27 : !tt.ptr, i64 loc(#loc343) + %DV = tt.addptr %arg_DV, %dv_adj_29 : !tt.ptr, i64 loc(#loc344) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc345) + %10 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS_25 : i32 loc(#loc28) + scf.if %10 { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS_25 : i32 loc(#loc346) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS_26 : i32 loc(#loc347) + %off_hq2_30 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc348) + %off_hq2_31 = arith.addi %off_hq2, %off_hq2_30 : i32 loc(#loc349) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS_26 : i32 loc(#loc350) + %sparse_kv_idx_offset = arith.muli %start_m2_block, %ks4 : i32 loc(#loc351) + %q_adj2 = arith.muli %off_hq2_31, %c128_i32 : i32 loc(#loc352) + %q_adj2_32 = arith.muli %0, %off_zq : i32 loc(#loc353) + %q_adj2_33 = arith.addi %q_adj2, %q_adj2_32 : i32 loc(#loc354) + %q_adj2_34 = arith.extsi %q_adj2_33 : i32 to i64 loc(#loc355) + %do_adj2 = arith.muli %8, %off_hq2_31 : i32 loc(#loc356) + %do_adj2_35 = arith.muli %7, %off_zq : i32 loc(#loc357) + %do_adj2_36 = arith.addi %do_adj2, %do_adj2_35 : i32 loc(#loc358) + %do_adj2_37 = arith.extsi %do_adj2_36 : i32 to i64 loc(#loc359) + %off_chz2 = arith.muli %off_zq, %HQ : i32 loc(#loc360) + %off_chz2_38 = arith.addi %off_chz2, %off_hq2_31 : i32 loc(#loc361) + %off_chz2_39 = arith.muli %off_chz2_38, %ks0 : i32 loc(#loc362) + %off_chz2_40 = arith.extsi %off_chz2_39 : i32 to i64 loc(#loc363) + %Q2 = tt.addptr %arg_Q, %q_adj2_34 : !tt.ptr, i64 loc(#loc364) + %DO2 = tt.addptr %arg_DO, %do_adj2_37 : !tt.ptr, i64 loc(#loc365) + %DQ2 = tt.addptr %arg_DQ, %q_adj2_34 : !tt.ptr, i64 loc(#loc366) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_40 : !tt.ptr, i64 loc(#loc367) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_40 : !tt.ptr, i64 loc(#loc368) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc369) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32> loc(#loc370) + %offs_m2_41 = arith.addi %offs_m2, %offs_k : tensor<128xi32> loc(#loc370) + %ptr = tt.expand_dims %offs_m2_41 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc606) + %ptr_42 = arith.muli %ptr, %cst_22 : tensor<128x1xi32> loc(#loc607) + %ptr_43 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc608) + %ptr_44 = tt.addptr %ptr_43, %ptr_42 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc608) + %ptr_45 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc609) + %ptr_46 = tt.broadcast %ptr_44 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc610) + %ptr_47 = tt.broadcast %ptr_45 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc610) + %ptr_48 = tt.addptr %ptr_46, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc610) + %q = tt.splat %ks0 : i32 -> tensor<128x1xi32> loc(#loc611) + %q_49 = arith.cmpi slt, %ptr, %q : tensor<128x1xi32> loc(#loc611) + %q_50 = tt.broadcast %q_49 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc612) + %q_51 = tt.load %ptr_48, %q_50, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc612) + %ptr_52 = arith.muli %ptr, %cst_2 : tensor<128x1xi32> loc(#loc613) + %ptr_53 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc614) + %ptr_54 = tt.addptr %ptr_53, %ptr_52 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc614) + %ptr_55 = tt.broadcast %ptr_54 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc615) + %ptr_56 = tt.addptr %ptr_55, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc615) + %do = tt.load %ptr_56, %q_50, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc616) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc378) + %Di_57 = arith.cmpi slt, %offs_m2_41, %Di : tensor<128xi32> loc(#loc378) + %Di_58 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc379) + %Di_59 = tt.addptr %Di_58, %offs_m2_41 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc379) + %Di_60 = tt.load %Di_59, %Di_57 : tensor<128x!tt.ptr> loc(#loc380) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc381) + %lse_61 = tt.addptr %lse, %offs_m2_41 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc381) + %lse_62 = tt.load %lse_61, %Di_57 : tensor<128x!tt.ptr> loc(#loc382) + %lse_63 = arith.cmpf oeq, %lse_62, %cst_24 : tensor<128xf32> loc(#loc383) + %lse_64 = arith.select %lse_63, %cst_23, %lse_62 : tensor<128xi1>, tensor<128xf32> loc(#loc384) + %lse_65 = tt.expand_dims %lse_64 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc385) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset : !tt.ptr, i32 loc(#loc386) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc387) + %kv_start_66 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc388) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc389) + %sparse_kv_num_blocks_67 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc390) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc391) + %offs_n2_68 = tt.splat %kv_start_66 : i32 -> tensor<64xi32> loc(#loc392) + %offs_n2_69 = arith.addi %offs_n2_68, %offs_n2 : tensor<64xi32> loc(#loc392) + %kT_ptrs = tt.expand_dims %offs_n2_69 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc617) + %kT_ptrs_70 = arith.muli %kT_ptrs, %cst_1 : tensor<1x64xi32> loc(#loc618) + %kT_ptrs_71 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc619) + %kT_ptrs_72 = tt.addptr %kT_ptrs_71, %kT_ptrs_70 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc619) + %kT_ptrs_73 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc620) + %kT_ptrs_74 = tt.broadcast %kT_ptrs_72 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc621) + %kT_ptrs_75 = tt.broadcast %kT_ptrs_73 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc621) + %kT_ptrs_76 = tt.addptr %kT_ptrs_74, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc621) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc622) + %vT_ptrs_77 = tt.addptr %vT_ptrs, %kT_ptrs_70 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc622) + %vT_ptrs_78 = tt.broadcast %vT_ptrs_77 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc623) + %vT_ptrs_79 = tt.addptr %vT_ptrs_78, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc623) + %hi = arith.muli %sparse_kv_num_blocks_67, %c2_i32 : i32 loc(#loc624) + %hi_80 = arith.addi %ks1, %c63_i32 : i32 loc(#loc780) + %hi_81 = arith.divsi %hi_80, %c64_i32 : i32 loc(#loc781) + %hi_82 = arith.maxsi %hi_81, %c1_i32 : i32 loc(#loc626) + %hi_83 = arith.minsi %hi, %hi_82 : i32 loc(#loc627) + %vT_ptrs_84:4 = scf.for %start_n = %c0_i32 to %hi_83 step %c1_i32 iter_args(%dq_106 = %cst_18, %offs_n2_107 = %offs_n2_69, %kT_ptrs_108 = %kT_ptrs_76, %vT_ptrs_109 = %vT_ptrs_79) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.expand_dims %offs_n2_107 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc929) + %kT_110 = tt.splat %ks1 : i32 -> tensor<1x64xi32> loc(#loc930) + %kT_111 = arith.cmpi slt, %kT, %kT_110 : tensor<1x64xi32> loc(#loc930) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc931) + %kT_113 = tt.load %kT_ptrs_108, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc931) + %qk = tt.dot %q_51, %kT_113, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc784) + %qk_114 = arith.mulf %qk, %cst_14 : tensor<128x64xf32> loc(#loc785) + %n = arith.remsi %kT, %kT_110 : tensor<1x64xi32> loc(#loc932) + %m = arith.remsi %ptr, %q : tensor<128x1xi32> loc(#loc933) + %post_mod_scores = arith.select %kT_112, %qk_114, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc788) + %tmp3 = arith.cmpi slt, %m, %cst_11 : tensor<128x1xi32> loc(#loc789) + %tmp5 = tt.broadcast %n : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc790) + %tmp5_115 = tt.broadcast %m : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc790) + %tmp5_116 = arith.cmpi sle, %tmp5, %tmp5_115 : tensor<128x64xi32> loc(#loc790) + %tmp6 = tt.broadcast %tmp3 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc791) + %tmp6_117 = arith.andi %tmp6, %tmp5_116 : tensor<128x64xi1> loc(#loc791) + %tmp7 = arith.cmpi sge, %m, %cst_11 : tensor<128x1xi32> loc(#loc792) + %tmp8 = arith.cmpi slt, %n, %cst_12 : tensor<1x64xi32> loc(#loc793) + %tmp9 = tt.broadcast %tmp7 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc794) + %tmp9_118 = tt.broadcast %tmp8 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc794) + %tmp9_119 = arith.andi %tmp9, %tmp9_118 : tensor<128x64xi1> loc(#loc794) + %tmp10 = arith.extui %tmp8 : tensor<1x64xi1> to tensor<1x64xi32> loc(#loc795) + %tmp10_120 = arith.cmpi eq, %tmp10, %cst_12 : tensor<1x64xi32> loc(#loc795) + %tmp11 = tt.broadcast %tmp10_120 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc796) + %tmp11_121 = arith.andi %tmp9, %tmp11 : tensor<128x64xi1> loc(#loc796) + %tmp14 = arith.remsi %m, %cst_7 : tensor<128x1xi32> loc(#loc797) + %tmp14_122 = arith.cmpi ne, %tmp14, %cst_11 : tensor<128x1xi32> loc(#loc798) + %tmp14_123 = arith.divsi %m, %cst_7 : tensor<128x1xi32> loc(#loc799) + %tmp14_124 = arith.subi %tmp14_123, %cst_10 : tensor<128x1xi32> loc(#loc800) + %tmp14_125 = arith.select %tmp14_122, %tmp14_124, %tmp14_123 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc801) + %tmp14_126 = arith.select %tmp3, %tmp14_125, %tmp14_123 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc802) + %tmp16 = arith.remsi %n, %cst_6 : tensor<1x64xi32> loc(#loc803) + %tmp16_127 = arith.cmpi ne, %tmp16, %cst_12 : tensor<1x64xi32> loc(#loc804) + %tmp16_128 = arith.divsi %n, %cst_6 : tensor<1x64xi32> loc(#loc805) + %tmp16_129 = arith.subi %tmp16_128, %cst_9 : tensor<1x64xi32> loc(#loc806) + %tmp16_130 = arith.select %tmp16_127, %tmp16_129, %tmp16_128 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc807) + %tmp16_131 = arith.select %tmp8, %tmp16_130, %tmp16_128 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc808) + %tmp17 = tt.broadcast %tmp14_126 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc809) + %tmp17_132 = tt.broadcast %tmp16_131 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc809) + %tmp17_133 = arith.cmpi eq, %tmp17, %tmp17_132 : tensor<128x64xi32> loc(#loc809) + %tmp18 = arith.andi %tmp11_121, %tmp17_133 : tensor<128x64xi1> loc(#loc810) + %tmp19 = arith.ori %tmp9_119, %tmp18 : tensor<128x64xi1> loc(#loc811) + %tmp20 = arith.ori %tmp6_117, %tmp19 : tensor<128x64xi1> loc(#loc812) + %post_mod_scores_134 = arith.select %tmp20, %post_mod_scores, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc813) + %post_mod_scores_135 = arith.mulf %post_mod_scores_134, %cst_8 : tensor<128x64xf32> loc(#loc814) + %p = tt.broadcast %lse_65 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc815) + %p_136 = arith.subf %post_mod_scores_135, %p : tensor<128x64xf32> loc(#loc815) + %p_137 = math.exp2 %p_136 : tensor<128x64xf32> loc(#loc816) + %vT = tt.load %vT_ptrs_109, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc934) + %dp = tt.dot %do, %vT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc818) + %ds = tt.expand_dims %Di_60 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc819) + %ds_138 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc820) + %ds_139 = arith.subf %dp, %ds_138 : tensor<128x64xf32> loc(#loc820) + %ds_140 = arith.mulf %p_137, %ds_139 : tensor<128x64xf32> loc(#loc821) + %grad_scores = arith.select %kT_112, %ds_140, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc822) + %ds_141 = arith.select %tmp20, %grad_scores, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc823) + %ds_142 = arith.truncf %ds_141 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc824) + %dq_143 = tt.trans %kT_113 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc825) + %dq_144 = tt.dot %ds_142, %dq_143, %dq_106, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc826) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc827) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc828) + %cur_block_145 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc829) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc830) + %next_block_146 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_67 : i32 loc(#loc831) + %next_block_147 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc832) + %next_block_148 = tt.load %next_block_147, %next_block_146 evictionPolicy = evict_last : !tt.ptr loc(#loc833) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc834) + %needs_jump_149 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc835) + %needs_jump_150 = arith.cmpi eq, %needs_jump_149, %c0_i32 : i32 loc(#loc836) + %jump_to_block = arith.subi %next_block_148, %cur_block_145 : i32 loc(#loc837) + %jump_to_block_151 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc838) + %jump_to_block_152 = arith.subi %jump_to_block_151, %c64_i32 : i32 loc(#loc839) + %offset = arith.extui %needs_jump_150 : i1 to i32 loc(#loc840) + %offset_153 = arith.muli %jump_to_block_152, %offset : i32 loc(#loc840) + %offset_154 = arith.subi %c1_i32, %offset : i32 loc(#loc841) + %offset_155 = arith.muli %offset_154, %c64_i32 : i32 loc(#loc842) + %offset_156 = arith.addi %offset_153, %offset_155 : i32 loc(#loc843) + %kT_ptrs_157 = arith.muli %offset_156, %c1024_i32 : i32 loc(#loc631) + %kT_ptrs_158 = tt.splat %kT_ptrs_157 : i32 -> tensor<128x64xi32> loc(#loc632) + %kT_ptrs_159 = tt.addptr %kT_ptrs_108, %kT_ptrs_158 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc632) + %vT_ptrs_160 = tt.addptr %vT_ptrs_109, %kT_ptrs_158 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc633) + %offs_n2_161 = tt.splat %offset_156 : i32 -> tensor<64xi32> loc(#loc634) + %offs_n2_162 = arith.addi %offs_n2_107, %offs_n2_161 : tensor<64xi32> loc(#loc634) + scf.yield %dq_144, %offs_n2_162, %kT_ptrs_159, %vT_ptrs_160 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc635) + } loc(#loc940) + %kv_indices_85 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset : !tt.ptr, i32 loc(#loc473) + %kv_start_86 = tt.load %kv_indices_85 : !tt.ptr loc(#loc474) + %kv_start_87 = arith.muli %kv_start_86, %c128_i32 : i32 loc(#loc475) + %sparse_kv_num_blocks_88 = tt.addptr %arg_FULL_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc476) + %sparse_kv_num_blocks_89 = tt.load %sparse_kv_num_blocks_88 : !tt.ptr loc(#loc477) + %offs_n2_90 = tt.splat %kv_start_87 : i32 -> tensor<64xi32> loc(#loc478) + %offs_n2_91 = arith.addi %offs_n2_90, %offs_n2 : tensor<64xi32> loc(#loc478) + %kT_ptrs_92 = tt.expand_dims %offs_n2_91 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc636) + %kT_ptrs_93 = arith.muli %kT_ptrs_92, %cst_1 : tensor<1x64xi32> loc(#loc637) + %kT_ptrs_94 = tt.addptr %kT_ptrs_71, %kT_ptrs_93 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc638) + %kT_ptrs_95 = tt.broadcast %kT_ptrs_94 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc639) + %kT_ptrs_96 = tt.addptr %kT_ptrs_95, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc639) + %vT_ptrs_97 = tt.addptr %vT_ptrs, %kT_ptrs_93 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc640) + %vT_ptrs_98 = tt.broadcast %vT_ptrs_97 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc641) + %vT_ptrs_99 = tt.addptr %vT_ptrs_98, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc641) + %hi_100 = arith.muli %sparse_kv_num_blocks_89, %c2_i32 : i32 loc(#loc642) + %hi_101 = arith.minsi %hi_100, %hi_82 : i32 loc(#loc643) + %vT_ptrs_102:4 = scf.for %start_n = %c0_i32 to %hi_101 step %c1_i32 iter_args(%dq_106 = %vT_ptrs_84#0, %offs_n2_107 = %offs_n2_91, %kT_ptrs_108 = %kT_ptrs_96, %vT_ptrs_109 = %vT_ptrs_99) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.expand_dims %offs_n2_107 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc935) + %kT_110 = tt.splat %ks1 : i32 -> tensor<1x64xi32> loc(#loc936) + %kT_111 = arith.cmpi slt, %kT, %kT_110 : tensor<1x64xi32> loc(#loc936) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc937) + %kT_113 = tt.load %kT_ptrs_108, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc937) + %qk = tt.dot %q_51, %kT_113, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc845) + %qk_114 = arith.mulf %qk, %cst_14 : tensor<128x64xf32> loc(#loc846) + %post_mod_scores = arith.select %kT_112, %qk_114, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc847) + %post_mod_scores_115 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32> loc(#loc848) + %p = tt.broadcast %lse_65 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc849) + %p_116 = arith.subf %post_mod_scores_115, %p : tensor<128x64xf32> loc(#loc849) + %p_117 = math.exp2 %p_116 : tensor<128x64xf32> loc(#loc850) + %vT = tt.load %vT_ptrs_109, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc938) + %dp = tt.dot %do, %vT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc852) + %ds = tt.expand_dims %Di_60 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc853) + %ds_118 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc854) + %ds_119 = arith.subf %dp, %ds_118 : tensor<128x64xf32> loc(#loc854) + %ds_120 = arith.mulf %p_117, %ds_119 : tensor<128x64xf32> loc(#loc855) + %grad_scores = arith.select %kT_112, %ds_120, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc856) + %ds_121 = arith.truncf %grad_scores : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc857) + %dq_122 = tt.trans %kT_113 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc858) + %dq_123 = tt.dot %ds_121, %dq_122, %dq_106, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc859) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc860) + %cur_block = tt.addptr %kv_indices_85, %cur_block_idx : !tt.ptr, i32 loc(#loc861) + %cur_block_124 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc862) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc863) + %next_block_125 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_89 : i32 loc(#loc864) + %next_block_126 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc865) + %next_block_127 = tt.load %next_block_126, %next_block_125 evictionPolicy = evict_last : !tt.ptr loc(#loc866) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc867) + %needs_jump_128 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc868) + %needs_jump_129 = arith.cmpi eq, %needs_jump_128, %c0_i32 : i32 loc(#loc869) + %jump_to_block = arith.subi %next_block_127, %cur_block_124 : i32 loc(#loc870) + %jump_to_block_130 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc871) + %jump_to_block_131 = arith.subi %jump_to_block_130, %c64_i32 : i32 loc(#loc872) + %offset = arith.extui %needs_jump_129 : i1 to i32 loc(#loc873) + %offset_132 = arith.muli %jump_to_block_131, %offset : i32 loc(#loc873) + %offset_133 = arith.subi %c1_i32, %offset : i32 loc(#loc874) + %offset_134 = arith.muli %offset_133, %c64_i32 : i32 loc(#loc875) + %offset_135 = arith.addi %offset_132, %offset_134 : i32 loc(#loc876) + %kT_ptrs_136 = arith.muli %offset_135, %c1024_i32 : i32 loc(#loc646) + %kT_ptrs_137 = tt.splat %kT_ptrs_136 : i32 -> tensor<128x64xi32> loc(#loc647) + %kT_ptrs_138 = tt.addptr %kT_ptrs_108, %kT_ptrs_137 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc647) + %vT_ptrs_139 = tt.addptr %vT_ptrs_109, %kT_ptrs_137 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc648) + %offs_n2_140 = tt.splat %offset_135 : i32 -> tensor<64xi32> loc(#loc649) + %offs_n2_141 = arith.addi %offs_n2_107, %offs_n2_140 : tensor<64xi32> loc(#loc649) + scf.yield %dq_123, %offs_n2_141, %kT_ptrs_138, %vT_ptrs_139 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc650) + } loc(#loc941) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc480) + %dq_ptrs_103 = tt.addptr %dq_ptrs, %ptr_42 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc480) + %dq_ptrs_104 = tt.broadcast %dq_ptrs_103 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc481) + %dq_ptrs_105 = tt.addptr %dq_ptrs_104, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc481) + %dq = arith.mulf %vT_ptrs_102#0, %cst_21 : tensor<128x128xf32> loc(#loc482) + %11 = arith.cmpi slt, %ptr_45, %cst_20 : tensor<1x128xi32> loc(#loc174) + %12 = tt.broadcast %11 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc175) + %13 = arith.andi %q_50, %12 : tensor<128x128xi1> loc(#loc175) + %14 = arith.truncf %dq : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc176) + tt.store %dq_ptrs_105, %14, %13 : tensor<128x128x!tt.ptr> loc(#loc176) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc483) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32> loc(#loc484) + %offs_n1_30 = arith.addi %offs_n1, %offs_k : tensor<128xi32> loc(#loc484) + %ptr = tt.expand_dims %offs_n1_30 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc651) + %ptr_31 = arith.muli %ptr, %cst_19 : tensor<128x1xi32> loc(#loc652) + %ptr_32 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc653) + %ptr_33 = tt.addptr %ptr_32, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc653) + %ptr_34 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc654) + %ptr_35 = tt.broadcast %ptr_33 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc655) + %ptr_36 = tt.broadcast %ptr_34 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc655) + %ptr_37 = tt.addptr %ptr_35, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc655) + %k = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc656) + %k_38 = arith.cmpi slt, %ptr, %k : tensor<128x1xi32> loc(#loc656) + %k_39 = tt.broadcast %k_38 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc657) + %k_40 = tt.load %ptr_37, %k_39, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc657) + %ptr_41 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc658) + %ptr_42 = tt.addptr %ptr_41, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc658) + %ptr_43 = tt.broadcast %ptr_42 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc659) + %ptr_44 = tt.addptr %ptr_43, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc659) + %v = tt.load %ptr_44, %k_39, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc660) + %dk:2 = scf.for %off_g = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%dv = %cst_18, %dk_49 = %cst_18) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc488) + %off_hq1_50 = arith.addi %off_hq1, %off_g : i32 loc(#loc489) + %q_adj1 = arith.muli %off_hq1_50, %c128_i32 : i32 loc(#loc490) + %q_adj1_51 = arith.muli %0, %off_zq : i32 loc(#loc491) + %q_adj1_52 = arith.addi %q_adj1, %q_adj1_51 : i32 loc(#loc492) + %q_adj1_53 = arith.extsi %q_adj1_52 : i32 to i64 loc(#loc493) + %do_adj1 = arith.muli %8, %off_hq1_50 : i32 loc(#loc494) + %do_adj1_54 = arith.muli %7, %off_zq : i32 loc(#loc495) + %do_adj1_55 = arith.addi %do_adj1, %do_adj1_54 : i32 loc(#loc496) + %do_adj1_56 = arith.extsi %do_adj1_55 : i32 to i64 loc(#loc497) + %off_chz1 = arith.muli %off_zq, %HQ : i32 loc(#loc498) + %off_chz1_57 = arith.addi %off_chz1, %off_hq1_50 : i32 loc(#loc499) + %off_chz1_58 = arith.muli %off_chz1_57, %ks0 : i32 loc(#loc500) + %off_chz1_59 = arith.extsi %off_chz1_58 : i32 to i64 loc(#loc501) + %Q1 = tt.addptr %arg_Q, %q_adj1_53 : !tt.ptr, i64 loc(#loc502) + %DO1 = tt.addptr %arg_DO, %do_adj1_56 : !tt.ptr, i64 loc(#loc503) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_59 : !tt.ptr, i64 loc(#loc504) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_59 : !tt.ptr, i64 loc(#loc505) + %sparse_q_idx_offset = arith.muli %pid, %ks6 : i32 loc(#loc506) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset : !tt.ptr, i32 loc(#loc507) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc508) + %q_start_60 = arith.muli %q_start, %c128_i32 : i32 loc(#loc509) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc510) + %sparse_q_num_blocks_61 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc511) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc512) + %offs_m1_62 = tt.splat %q_start_60 : i32 -> tensor<64xi32> loc(#loc513) + %offs_m1_63 = arith.addi %offs_m1_62, %offs_m1 : tensor<64xi32> loc(#loc513) + %qT_ptrs = tt.expand_dims %offs_m1_63 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc662) + %qT_ptrs_64 = arith.muli %qT_ptrs, %cst_0 : tensor<1x64xi32> loc(#loc663) + %qT_ptrs_65 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc664) + %qT_ptrs_66 = tt.addptr %qT_ptrs_65, %qT_ptrs_64 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc664) + %qT_ptrs_67 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc665) + %qT_ptrs_68 = tt.broadcast %qT_ptrs_66 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc666) + %qT_ptrs_69 = tt.broadcast %qT_ptrs_67 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc666) + %qT_ptrs_70 = tt.addptr %qT_ptrs_68, %qT_ptrs_69 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc666) + %do_ptrs = tt.expand_dims %offs_m1_63 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc667) + %do_ptrs_71 = arith.muli %do_ptrs, %cst : tensor<64x1xi32> loc(#loc668) + %do_ptrs_72 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc669) + %do_ptrs_73 = tt.addptr %do_ptrs_72, %do_ptrs_71 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc669) + %do_ptrs_74 = tt.broadcast %do_ptrs_73 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc670) + %do_ptrs_75 = tt.broadcast %ptr_34 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc670) + %do_ptrs_76 = tt.addptr %do_ptrs_74, %do_ptrs_75 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc670) + %hi = arith.muli %sparse_q_num_blocks_61, %c2_i32 : i32 loc(#loc671) + %hi_77 = arith.addi %ks0, %c63_i32 : i32 loc(#loc877) + %hi_78 = arith.divsi %hi_77, %c64_i32 : i32 loc(#loc878) + %hi_79 = arith.maxsi %hi_78, %c1_i32 : i32 loc(#loc673) + %hi_80 = arith.minsi %hi, %hi_79 : i32 loc(#loc674) + %do_ptrs_81:5 = scf.for %start_m = %c0_i32 to %hi_80 step %c1_i32 iter_args(%dk_102 = %dk_49, %dv_103 = %dv, %offs_m1_104 = %offs_m1_63, %qT_ptrs_105 = %qT_ptrs_70, %do_ptrs_106 = %do_ptrs_76) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.expand_dims %offs_m1_104 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc880) + %qT_107 = tt.splat %ks0 : i32 -> tensor<1x64xi32> loc(#loc881) + %qT_108 = arith.cmpi slt, %qT, %qT_107 : tensor<1x64xi32> loc(#loc881) + %qT_109 = tt.broadcast %qT_108 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc882) + %qT_110 = tt.load %qT_ptrs_105, %qT_109, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc882) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32> loc(#loc677) + %lse_111 = arith.cmpi slt, %offs_m1_104, %lse : tensor<64xi32> loc(#loc677) + %lse_112 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc678) + %lse_113 = tt.addptr %lse_112, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc678) + %lse_114 = tt.load %lse_113, %lse_111 : tensor<64x!tt.ptr> loc(#loc679) + %lse_115 = arith.cmpf oeq, %lse_114, %cst_5 : tensor<64xf32> loc(#loc680) + %lse_116 = arith.select %lse_115, %cst_4, %lse_114 : tensor<64xi1>, tensor<64xf32> loc(#loc681) + %qkT = tt.dot %k_40, %qT_110, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc682) + %qkT_117 = arith.mulf %qkT, %cst_14 : tensor<128x64xf32> loc(#loc683) + %m = arith.remsi %qT, %qT_107 : tensor<1x64xi32> loc(#loc883) + %n = arith.remsi %ptr, %k : tensor<128x1xi32> loc(#loc884) + %post_mod_scores = arith.select %qT_109, %qkT_117, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc686) + %tmp25 = arith.cmpi slt, %m, %cst_12 : tensor<1x64xi32> loc(#loc687) + %tmp27 = tt.broadcast %n : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc688) + %tmp27_118 = tt.broadcast %m : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc688) + %tmp27_119 = arith.cmpi sle, %tmp27, %tmp27_118 : tensor<128x64xi32> loc(#loc688) + %tmp28 = tt.broadcast %tmp25 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc689) + %tmp28_120 = arith.andi %tmp28, %tmp27_119 : tensor<128x64xi1> loc(#loc689) + %tmp29 = arith.cmpi sge, %m, %cst_12 : tensor<1x64xi32> loc(#loc690) + %tmp30 = arith.cmpi slt, %n, %cst_11 : tensor<128x1xi32> loc(#loc691) + %tmp31 = tt.broadcast %tmp29 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc692) + %tmp31_121 = tt.broadcast %tmp30 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc692) + %tmp31_122 = arith.andi %tmp31, %tmp31_121 : tensor<128x64xi1> loc(#loc692) + %tmp32 = arith.extui %tmp30 : tensor<128x1xi1> to tensor<128x1xi32> loc(#loc693) + %tmp32_123 = arith.cmpi eq, %tmp32, %cst_11 : tensor<128x1xi32> loc(#loc693) + %tmp33 = tt.broadcast %tmp32_123 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc694) + %tmp33_124 = arith.andi %tmp31, %tmp33 : tensor<128x64xi1> loc(#loc694) + %tmp36 = arith.remsi %m, %cst_6 : tensor<1x64xi32> loc(#loc695) + %tmp36_125 = arith.cmpi ne, %tmp36, %cst_12 : tensor<1x64xi32> loc(#loc696) + %tmp36_126 = arith.divsi %m, %cst_6 : tensor<1x64xi32> loc(#loc697) + %tmp36_127 = arith.subi %tmp36_126, %cst_9 : tensor<1x64xi32> loc(#loc698) + %tmp36_128 = arith.select %tmp36_125, %tmp36_127, %tmp36_126 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc699) + %tmp36_129 = arith.select %tmp25, %tmp36_128, %tmp36_126 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc700) + %tmp38 = arith.remsi %n, %cst_7 : tensor<128x1xi32> loc(#loc701) + %tmp38_130 = arith.cmpi ne, %tmp38, %cst_11 : tensor<128x1xi32> loc(#loc702) + %tmp38_131 = arith.divsi %n, %cst_7 : tensor<128x1xi32> loc(#loc703) + %tmp38_132 = arith.subi %tmp38_131, %cst_10 : tensor<128x1xi32> loc(#loc704) + %tmp38_133 = arith.select %tmp38_130, %tmp38_132, %tmp38_131 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc705) + %tmp38_134 = arith.select %tmp30, %tmp38_133, %tmp38_131 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc706) + %tmp39 = tt.broadcast %tmp36_129 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc707) + %tmp39_135 = tt.broadcast %tmp38_134 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc707) + %tmp39_136 = arith.cmpi eq, %tmp39, %tmp39_135 : tensor<128x64xi32> loc(#loc707) + %tmp40 = arith.andi %tmp33_124, %tmp39_136 : tensor<128x64xi1> loc(#loc708) + %tmp41 = arith.ori %tmp31_122, %tmp40 : tensor<128x64xi1> loc(#loc709) + %tmp42 = arith.ori %tmp28_120, %tmp41 : tensor<128x64xi1> loc(#loc710) + %post_mod_scores_137 = arith.select %tmp42, %post_mod_scores, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc711) + %post_mod_scores_138 = arith.mulf %post_mod_scores_137, %cst_8 : tensor<128x64xf32> loc(#loc712) + %pT = tt.expand_dims %lse_116 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc713) + %pT_139 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc714) + %pT_140 = arith.subf %post_mod_scores_138, %pT_139 : tensor<128x64xf32> loc(#loc714) + %pT_141 = math.exp2 %pT_140 : tensor<128x64xf32> loc(#loc715) + %do = tt.expand_dims %offs_m1_104 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc885) + %do_142 = tt.splat %ks0 : i32 -> tensor<64x1xi32> loc(#loc886) + %do_143 = arith.cmpi slt, %do, %do_142 : tensor<64x1xi32> loc(#loc886) + %do_144 = tt.broadcast %do_143 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc887) + %do_145 = tt.load %do_ptrs_106, %do_144, %cst_3 : tensor<64x128x!tt.ptr> loc(#loc887) + %dv_146 = arith.truncf %pT_141 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc717) + %dv_147 = tt.dot %dv_146, %do_145, %dv_103, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc718) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc719) + %Di_148 = tt.addptr %Di, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc719) + %Di_149 = tt.load %Di_148, %lse_111 : tensor<64x!tt.ptr> loc(#loc720) + %dpT = tt.trans %do_145 {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc721) + %dpT_150 = tt.dot %v, %dpT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc722) + %dsT = tt.expand_dims %Di_149 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc723) + %dsT_151 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc724) + %dsT_152 = arith.subf %dpT_150, %dsT_151 : tensor<128x64xf32> loc(#loc724) + %dsT_153 = arith.mulf %pT_141, %dsT_152 : tensor<128x64xf32> loc(#loc725) + %grad_scores = arith.select %qT_109, %dsT_153, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc726) + %dsT_154 = arith.select %tmp42, %grad_scores, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc727) + %dk_155 = arith.truncf %dsT_154 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc728) + %dk_156 = tt.trans %qT_110 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc729) + %dk_157 = tt.dot %dk_155, %dk_156, %dk_102, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc730) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc888) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc889) + %cur_block_158 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc890) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc891) + %next_block_159 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_61 : i32 loc(#loc892) + %next_block_160 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc893) + %next_block_161 = tt.load %next_block_160, %next_block_159 evictionPolicy = evict_last : !tt.ptr loc(#loc894) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc895) + %needs_jump_162 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc896) + %needs_jump_163 = arith.cmpi eq, %needs_jump_162, %c0_i32 : i32 loc(#loc897) + %jump_to_block = arith.subi %next_block_161, %cur_block_158 : i32 loc(#loc898) + %jump_to_block_164 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc899) + %jump_to_block_165 = arith.subi %jump_to_block_164, %c64_i32 : i32 loc(#loc900) + %offset = arith.extui %needs_jump_163 : i1 to i32 loc(#loc901) + %offset_166 = arith.muli %jump_to_block_165, %offset : i32 loc(#loc901) + %offset_167 = arith.subi %c1_i32, %offset : i32 loc(#loc902) + %offset_168 = arith.muli %offset_167, %c64_i32 : i32 loc(#loc903) + %offset_169 = arith.addi %offset_166, %offset_168 : i32 loc(#loc904) + %qT_ptrs_170 = arith.muli %offset_169, %c4096_i32 : i32 loc(#loc732) + %qT_ptrs_171 = tt.splat %qT_ptrs_170 : i32 -> tensor<128x64xi32> loc(#loc733) + %qT_ptrs_172 = tt.addptr %qT_ptrs_105, %qT_ptrs_171 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc733) + %do_ptrs_173 = arith.muli %offset_169, %c128_i32 : i32 loc(#loc734) + %do_ptrs_174 = tt.splat %do_ptrs_173 : i32 -> tensor<64x128xi32> loc(#loc735) + %do_ptrs_175 = tt.addptr %do_ptrs_106, %do_ptrs_174 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc735) + %offs_m1_176 = tt.splat %offset_169 : i32 -> tensor<64xi32> loc(#loc736) + %offs_m1_177 = arith.addi %offs_m1_104, %offs_m1_176 : tensor<64xi32> loc(#loc736) + scf.yield %dk_157, %dv_147, %offs_m1_177, %qT_ptrs_172, %do_ptrs_175 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc590) + } loc(#loc943) + %q_indices_82 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset : !tt.ptr, i32 loc(#loc591) + %q_start_83 = tt.load %q_indices_82 : !tt.ptr loc(#loc592) + %q_start_84 = arith.muli %q_start_83, %c128_i32 : i32 loc(#loc593) + %sparse_q_num_blocks_85 = tt.addptr %arg_FULL_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc594) + %sparse_q_num_blocks_86 = tt.load %sparse_q_num_blocks_85 : !tt.ptr loc(#loc595) + %offs_m1_87 = tt.splat %q_start_84 : i32 -> tensor<64xi32> loc(#loc596) + %offs_m1_88 = arith.addi %offs_m1_87, %offs_m1 : tensor<64xi32> loc(#loc596) + %qT_ptrs_89 = tt.expand_dims %offs_m1_88 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc737) + %qT_ptrs_90 = arith.muli %qT_ptrs_89, %cst_0 : tensor<1x64xi32> loc(#loc738) + %qT_ptrs_91 = tt.addptr %qT_ptrs_65, %qT_ptrs_90 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc739) + %qT_ptrs_92 = tt.broadcast %qT_ptrs_91 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc740) + %qT_ptrs_93 = tt.addptr %qT_ptrs_92, %qT_ptrs_69 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc740) + %do_ptrs_94 = tt.expand_dims %offs_m1_88 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc741) + %do_ptrs_95 = arith.muli %do_ptrs_94, %cst : tensor<64x1xi32> loc(#loc742) + %do_ptrs_96 = tt.addptr %do_ptrs_72, %do_ptrs_95 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc743) + %do_ptrs_97 = tt.broadcast %do_ptrs_96 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc744) + %do_ptrs_98 = tt.addptr %do_ptrs_97, %do_ptrs_75 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc744) + %hi_99 = arith.muli %sparse_q_num_blocks_86, %c2_i32 : i32 loc(#loc745) + %hi_100 = arith.minsi %hi_99, %hi_79 : i32 loc(#loc746) + %do_ptrs_101:5 = scf.for %start_m = %c0_i32 to %hi_100 step %c1_i32 iter_args(%dk_102 = %do_ptrs_81#0, %dv_103 = %do_ptrs_81#1, %offs_m1_104 = %offs_m1_88, %qT_ptrs_105 = %qT_ptrs_93, %do_ptrs_106 = %do_ptrs_98) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.expand_dims %offs_m1_104 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc905) + %qT_107 = tt.splat %ks0 : i32 -> tensor<1x64xi32> loc(#loc906) + %qT_108 = arith.cmpi slt, %qT, %qT_107 : tensor<1x64xi32> loc(#loc906) + %qT_109 = tt.broadcast %qT_108 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc907) + %qT_110 = tt.load %qT_ptrs_105, %qT_109, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc907) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32> loc(#loc748) + %lse_111 = arith.cmpi slt, %offs_m1_104, %lse : tensor<64xi32> loc(#loc748) + %lse_112 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc749) + %lse_113 = tt.addptr %lse_112, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc749) + %lse_114 = tt.load %lse_113, %lse_111 : tensor<64x!tt.ptr> loc(#loc750) + %lse_115 = arith.cmpf oeq, %lse_114, %cst_5 : tensor<64xf32> loc(#loc751) + %lse_116 = arith.select %lse_115, %cst_4, %lse_114 : tensor<64xi1>, tensor<64xf32> loc(#loc752) + %qkT = tt.dot %k_40, %qT_110, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc753) + %qkT_117 = arith.mulf %qkT, %cst_14 : tensor<128x64xf32> loc(#loc754) + %post_mod_scores = arith.select %qT_109, %qkT_117, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc755) + %post_mod_scores_118 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32> loc(#loc756) + %pT = tt.expand_dims %lse_116 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc757) + %pT_119 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc758) + %pT_120 = arith.subf %post_mod_scores_118, %pT_119 : tensor<128x64xf32> loc(#loc758) + %pT_121 = math.exp2 %pT_120 : tensor<128x64xf32> loc(#loc759) + %do = tt.expand_dims %offs_m1_104 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc908) + %do_122 = tt.splat %ks0 : i32 -> tensor<64x1xi32> loc(#loc909) + %do_123 = arith.cmpi slt, %do, %do_122 : tensor<64x1xi32> loc(#loc909) + %do_124 = tt.broadcast %do_123 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc910) + %do_125 = tt.load %do_ptrs_106, %do_124, %cst_3 : tensor<64x128x!tt.ptr> loc(#loc910) + %dv_126 = arith.truncf %pT_121 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc761) + %dv_127 = tt.dot %dv_126, %do_125, %dv_103, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc762) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc763) + %Di_128 = tt.addptr %Di, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc763) + %Di_129 = tt.load %Di_128, %lse_111 : tensor<64x!tt.ptr> loc(#loc764) + %dpT = tt.trans %do_125 {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc765) + %dpT_130 = tt.dot %v, %dpT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc766) + %dsT = tt.expand_dims %Di_129 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc767) + %dsT_131 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc768) + %dsT_132 = arith.subf %dpT_130, %dsT_131 : tensor<128x64xf32> loc(#loc768) + %dsT_133 = arith.mulf %pT_121, %dsT_132 : tensor<128x64xf32> loc(#loc769) + %grad_scores = arith.select %qT_109, %dsT_133, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc770) + %dk_134 = arith.truncf %grad_scores : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc771) + %dk_135 = tt.trans %qT_110 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc772) + %dk_136 = tt.dot %dk_134, %dk_135, %dk_102, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc773) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc911) + %cur_block = tt.addptr %q_indices_82, %cur_block_idx : !tt.ptr, i32 loc(#loc912) + %cur_block_137 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc913) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc914) + %next_block_138 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_86 : i32 loc(#loc915) + %next_block_139 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc916) + %next_block_140 = tt.load %next_block_139, %next_block_138 evictionPolicy = evict_last : !tt.ptr loc(#loc917) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc918) + %needs_jump_141 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc919) + %needs_jump_142 = arith.cmpi eq, %needs_jump_141, %c0_i32 : i32 loc(#loc920) + %jump_to_block = arith.subi %next_block_140, %cur_block_137 : i32 loc(#loc921) + %jump_to_block_143 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc922) + %jump_to_block_144 = arith.subi %jump_to_block_143, %c64_i32 : i32 loc(#loc923) + %offset = arith.extui %needs_jump_142 : i1 to i32 loc(#loc924) + %offset_145 = arith.muli %jump_to_block_144, %offset : i32 loc(#loc924) + %offset_146 = arith.subi %c1_i32, %offset : i32 loc(#loc925) + %offset_147 = arith.muli %offset_146, %c64_i32 : i32 loc(#loc926) + %offset_148 = arith.addi %offset_145, %offset_147 : i32 loc(#loc927) + %qT_ptrs_149 = arith.muli %offset_148, %c4096_i32 : i32 loc(#loc775) + %qT_ptrs_150 = tt.splat %qT_ptrs_149 : i32 -> tensor<128x64xi32> loc(#loc776) + %qT_ptrs_151 = tt.addptr %qT_ptrs_105, %qT_ptrs_150 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc776) + %do_ptrs_152 = arith.muli %offset_148, %c128_i32 : i32 loc(#loc777) + %do_ptrs_153 = tt.splat %do_ptrs_152 : i32 -> tensor<64x128xi32> loc(#loc778) + %do_ptrs_154 = tt.addptr %do_ptrs_106, %do_ptrs_153 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc778) + %offs_m1_155 = tt.splat %offset_148 : i32 -> tensor<64xi32> loc(#loc779) + %offs_m1_156 = arith.addi %offs_m1_104, %offs_m1_155 : tensor<64xi32> loc(#loc779) + scf.yield %dk_136, %dv_127, %offs_m1_156, %qT_ptrs_151, %do_ptrs_154 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc598) + } loc(#loc944) + scf.yield %do_ptrs_101#1, %do_ptrs_101#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc294) + } loc(#loc661) + %dv_ptrs = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc599) + %dv_ptrs_45 = tt.addptr %dv_ptrs, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc599) + %dv_ptrs_46 = tt.broadcast %dv_ptrs_45 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc600) + %dv_ptrs_47 = tt.addptr %dv_ptrs_46, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc600) + %11 = arith.cmpi slt, %ptr_34, %cst_20 : tensor<1x128xi32> loc(#loc297) + %12 = tt.broadcast %11 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc298) + %13 = arith.andi %k_39, %12 : tensor<128x128xi1> loc(#loc298) + %14 = arith.truncf %dk#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc299) + tt.store %dv_ptrs_47, %14, %13 : tensor<128x128x!tt.ptr> loc(#loc299) + %dk_48 = arith.mulf %dk#1, %cst_21 : tensor<128x128xf32> loc(#loc601) + %15 = tt.splat %k_adj : i32 -> tensor<1x128xi32> loc(#loc301) + %16 = arith.addi %ptr_34, %15 : tensor<1x128xi32> loc(#loc301) + %17 = tt.broadcast %16 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc302) + %18 = tt.broadcast %ptr_31 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc302) + %19 = arith.addi %17, %18 : tensor<128x128xi32> loc(#loc302) + %20 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc303) + %21 = tt.addptr %20, %19 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc303) + %22 = arith.truncf %dk_48 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc304) + tt.store %21, %22, %k_39 : tensor<128x128x!tt.ptr> loc(#loc304) + } loc(#loc29) + tt.return loc(#loc305) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":103:9) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":94:54) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:74) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:66) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:100) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:91) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:82) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:59) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":97:111) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":100:58) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":111:24) +#loc13 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":112:36) +#loc15 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":113:34) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":115:27) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":116:28) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:25) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":124:59) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:50) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:37) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":128:61) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":131:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":132:9) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":133:10) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":136:26) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:14) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:7) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":140:24) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:29) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:54) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":144:44) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":145:35) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":155:83) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:30) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:52) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:40) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":158:63) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:32) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:55) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:42) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":159:66) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:30) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:35) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:46) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":161:56) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":163:17) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":164:19) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":167:19) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":168:21) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":169:25) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":174:36) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":175:29) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:27) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":178:107) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:38) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:20) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:56) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":789:49) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:52) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:23) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":179:111) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:58) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:34) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":188:25) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:33) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":189:26) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:30) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":190:50) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":191:18) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":195:30) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:27) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":196:41) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:53) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":197:39) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:42) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":199:29) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:26) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":207:12) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:37) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:18) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:56) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":390:49) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:18) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":391:49) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:43) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:90) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:101) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":395:63) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":397:28) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:41) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":458:105) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":405:12) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:52) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":795:23) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":459:19) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":461:14) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":762:21) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":464:46) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":467:46) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":476:79) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":481:22) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":483:23) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":484:22) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":485:23) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":486:22) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":487:22) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":488:24) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":489:23) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:70) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:79) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:91) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:99) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:102) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":492:119) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:70) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:79) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:91) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:99) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:102) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":494:119) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":495:25) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":496:24) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":497:23) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":498:23) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":503:69) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":506:27) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:39) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":507:21) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":510:104) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":512:20) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:22) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:19) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":513:14) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":520:71) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":531:43) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":533:15) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:30) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":535:21) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":752:33) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":411:64) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:38) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":753:24) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:109) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:113) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:55) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":754:25) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:30) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:35) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":755:60) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:34) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:48) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":756:63) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:29) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:47) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:61) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":757:42) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:28) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":414:19) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":415:19) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":417:19) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":417:8) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":214:39) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:31) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":215:45) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:62) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":216:43) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":218:33) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":226:16) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:24) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":231:56) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":232:14) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:87) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:69) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":236:30) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":252:25) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":253:29) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":256:107) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":257:107) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":262:30) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:32) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":263:51) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:34) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:56) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:44) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":266:67) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:36) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:59) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:46) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":267:70) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:34) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:39) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:50) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":269:60) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":271:21) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":272:23) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":275:25) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":276:29) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":282:81) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":286:32) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:30) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":287:43) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:55) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":288:42) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:45) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":290:32) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:26) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":298:16) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:37) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:18) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:56) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":583:49) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:27) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:38) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:19) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":584:51) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:42) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:87) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:98) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":590:61) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":592:28) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":651:105) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":600:12) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:52) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:28) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":656:22) +#loc228 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:26) +#loc229 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":657:46) +#loc230 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":658:20) +#loc231 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":660:15) +#loc232 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":662:46) +#loc233 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":665:46) +#loc234 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":674:78) +#loc235 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":679:24) +#loc236 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":681:25) +#loc237 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":682:24) +#loc238 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":683:25) +#loc239 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":684:24) +#loc240 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":685:24) +#loc241 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":686:25) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":687:24) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:70) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:79) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:91) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:99) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:102) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":690:119) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:70) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:79) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:91) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:99) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:102) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":692:119) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":693:25) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":694:24) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":695:24) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":696:24) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":700:69) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":703:27) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:44) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:40) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":704:22) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":797:41) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":705:99) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:24) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":708:43) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:29) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":712:21) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:29) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":714:20) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:25) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:22) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":715:16) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":723:70) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":737:45) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:24) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:52) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":739:43) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":605:62) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:28) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":608:19) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:28) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":609:19) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":610:19) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":610:8) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":306:41) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:34) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":307:47) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:64) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":308:46) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":310:36) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":318:20) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":303:12) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:23) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":323:55) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:71) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:61) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":332:30) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":334:14) +#loc301 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:55) +#loc302 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:69) +#loc303 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:29) +#loc304 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":345:99) +#loc305 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/7e/c7enonvqdt6bf22ypfamsbxi3rx5c5zj47uvqqhcgsolrpye7inz.py":139:4) +#loc331 = loc("HQ"(#loc2)) +#loc332 = loc("pid"(#loc12)) +#loc333 = loc("NUM_KV_BLOCKS"(#loc14)) +#loc334 = loc("NUM_Q_BLOCKS"(#loc16)) +#loc335 = loc("off_zq"(#loc17)) +#loc336 = loc("off_hkv"(#loc18)) +#loc337 = loc("k_adj"(#loc19)) +#loc338 = loc("k_adj"(#loc20)) +#loc339 = loc("dv_adj"(#loc21)) +#loc340 = loc("dv_adj"(#loc22)) +#loc341 = loc("dv_adj"(#loc23)) +#loc342 = loc("K"(#loc24)) +#loc343 = loc("V"(#loc25)) +#loc344 = loc("DV"(#loc26)) +#loc345 = loc("offs_k"(#loc27)) +#loc346 = loc("off_pid"(#loc30)) +#loc347 = loc("off_hq2"(#loc31)) +#loc348 = loc("off_hq2"(#loc32)) +#loc349 = loc("off_hq2"(#loc33)) +#loc350 = loc("start_m2_block"(#loc34)) +#loc351 = loc("sparse_kv_idx_offset"(#loc35)) +#loc352 = loc("q_adj2"(#loc36)) +#loc353 = loc("q_adj2"(#loc37)) +#loc354 = loc("q_adj2"(#loc38)) +#loc355 = loc("q_adj2"(#loc39)) +#loc356 = loc("do_adj2"(#loc40)) +#loc357 = loc("do_adj2"(#loc41)) +#loc358 = loc("do_adj2"(#loc42)) +#loc359 = loc("do_adj2"(#loc43)) +#loc360 = loc("off_chz2"(#loc44)) +#loc361 = loc("off_chz2"(#loc45)) +#loc362 = loc("off_chz2"(#loc46)) +#loc363 = loc("off_chz2"(#loc47)) +#loc364 = loc("Q2"(#loc48)) +#loc365 = loc("DO2"(#loc49)) +#loc366 = loc("DQ2"(#loc50)) +#loc367 = loc("LSE2"(#loc51)) +#loc368 = loc("DELTA2"(#loc52)) +#loc369 = loc("start_m2"(#loc53)) +#loc370 = loc("offs_m2"(#loc54)) +#loc371 = loc("ptr"(#loc55)) +#loc372 = loc("q"(#loc56)) +#loc373 = loc("ptr"(#loc57)) +#loc374 = loc("ptr"(#loc58)) +#loc375 = loc("ptr"(#loc59)) +#loc376 = loc("ptr"(#loc60)) +#loc377 = loc("do"(#loc63)) +#loc378 = loc("Di"(#loc64)) +#loc379 = loc("Di"(#loc65)) +#loc380 = loc("Di"(#loc66)) +#loc381 = loc("lse"(#loc67)) +#loc382 = loc("lse"(#loc68)) +#loc383 = loc("lse"(#loc69)) +#loc384 = loc("lse"(#loc70)) +#loc385 = loc("lse"(#loc71)) +#loc386 = loc("kv_indices"(#loc72)) +#loc387 = loc("kv_start"(#loc73)) +#loc388 = loc("kv_start"(#loc74)) +#loc389 = loc("sparse_kv_num_blocks"(#loc75)) +#loc390 = loc("sparse_kv_num_blocks"(#loc76)) +#loc391 = loc("offs_n2"(#loc77)) +#loc392 = loc("offs_n2"(#loc78)) +#loc393 = loc("kT_ptrs"(#loc79)) +#loc394 = loc("dq"(#loc80)) +#loc395 = loc("kT_ptrs"(#loc81)) +#loc396 = loc("kT_ptrs"(#loc82)) +#loc397 = loc("kT_ptrs"(#loc83)) +#loc398 = loc("kT_ptrs"(#loc84)) +#loc399 = loc("vT_ptrs"(#loc85)) +#loc400 = loc("vT_ptrs"(#loc86)) +#loc401 = loc("hi"(#loc87)) +#loc402 = loc("hi"(#loc88)) +#loc403 = loc("hi"(#loc89)) +#loc404 = loc("hi"(#loc90)) +#loc405 = loc("dq"(#loc91)) +#loc406 = loc("kT"(#loc93)) +#loc407 = loc("dq"(#loc94)) +#loc408 = loc("qk"(#loc97)) +#loc409 = loc("qk"(#loc98)) +#loc410 = loc("n"(#loc100)) +#loc411 = loc("m"(#loc101)) +#loc412 = loc("post_mod_scores"(#loc102)) +#loc413 = loc("tmp3"(#loc103)) +#loc414 = loc("tmp5"(#loc104)) +#loc415 = loc("tmp6"(#loc105)) +#loc416 = loc("tmp7"(#loc106)) +#loc417 = loc("tmp8"(#loc107)) +#loc418 = loc("tmp9"(#loc108)) +#loc419 = loc("tmp10"(#loc109)) +#loc420 = loc("tmp11"(#loc110)) +#loc421 = loc("tmp14"(#loc111)) +#loc422 = loc("tmp14"(#loc112)) +#loc423 = loc("tmp14"(#loc113)) +#loc424 = loc("tmp14"(#loc114)) +#loc425 = loc("tmp14"(#loc115)) +#loc426 = loc("tmp14"(#loc116)) +#loc427 = loc("tmp16"(#loc117)) +#loc428 = loc("tmp16"(#loc118)) +#loc429 = loc("tmp16"(#loc119)) +#loc430 = loc("tmp16"(#loc120)) +#loc431 = loc("tmp16"(#loc121)) +#loc432 = loc("tmp16"(#loc122)) +#loc433 = loc("tmp17"(#loc123)) +#loc434 = loc("tmp18"(#loc124)) +#loc435 = loc("tmp19"(#loc125)) +#loc436 = loc("tmp20"(#loc126)) +#loc437 = loc("post_mod_scores"(#loc127)) +#loc438 = loc("post_mod_scores"(#loc128)) +#loc439 = loc("p"(#loc129)) +#loc440 = loc("p"(#loc130)) +#loc441 = loc("vT"(#loc131)) +#loc442 = loc("dp"(#loc132)) +#loc443 = loc("ds"(#loc133)) +#loc444 = loc("ds"(#loc134)) +#loc445 = loc("ds"(#loc135)) +#loc446 = loc("grad_scores"(#loc136)) +#loc447 = loc("ds"(#loc137)) +#loc448 = loc("ds"(#loc138)) +#loc449 = loc("dq"(#loc139)) +#loc450 = loc("dq"(#loc140)) +#loc451 = loc("cur_block_idx"(#loc141)) +#loc452 = loc("offset"(#loc142)) +#loc453 = loc("cur_block"(#loc143)) +#loc454 = loc("cur_block"(#loc144)) +#loc455 = loc("next_block"(#loc145)) +#loc456 = loc("next_block"(#loc146)) +#loc457 = loc("next_block"(#loc147)) +#loc458 = loc("next_block"(#loc148)) +#loc459 = loc("needs_jump"(#loc149)) +#loc460 = loc("needs_jump"(#loc150)) +#loc461 = loc("needs_jump"(#loc151)) +#loc462 = loc("jump_to_block"(#loc152)) +#loc463 = loc("jump_to_block"(#loc153)) +#loc464 = loc("jump_to_block"(#loc154)) +#loc465 = loc("offset"(#loc155)) +#loc466 = loc("offset"(#loc156)) +#loc467 = loc("offset"(#loc157)) +#loc468 = loc("offset"(#loc158)) +#loc469 = loc("kT_ptrs"(#loc159)) +#loc470 = loc("kT_ptrs"(#loc160)) +#loc471 = loc("vT_ptrs"(#loc161)) +#loc472 = loc("offs_n2"(#loc162)) +#loc473 = loc("kv_indices"(#loc164)) +#loc474 = loc("kv_start"(#loc165)) +#loc475 = loc("kv_start"(#loc166)) +#loc476 = loc("sparse_kv_num_blocks"(#loc167)) +#loc477 = loc("sparse_kv_num_blocks"(#loc168)) +#loc478 = loc("offs_n2"(#loc169)) +#loc479 = loc("dq"(#loc170)) +#loc480 = loc("dq_ptrs"(#loc171)) +#loc481 = loc("dq_ptrs"(#loc172)) +#loc482 = loc("dq"(#loc173)) +#loc483 = loc("start_n1"(#loc177)) +#loc484 = loc("offs_n1"(#loc178)) +#loc485 = loc("k"(#loc179)) +#loc486 = loc("v"(#loc180)) +#loc487 = loc("dv"(#loc181)) +#loc488 = loc("off_hq1"(#loc182)) +#loc489 = loc("off_hq1"(#loc183)) +#loc490 = loc("q_adj1"(#loc184)) +#loc491 = loc("q_adj1"(#loc185)) +#loc492 = loc("q_adj1"(#loc186)) +#loc493 = loc("q_adj1"(#loc187)) +#loc494 = loc("do_adj1"(#loc188)) +#loc495 = loc("do_adj1"(#loc189)) +#loc496 = loc("do_adj1"(#loc190)) +#loc497 = loc("do_adj1"(#loc191)) +#loc498 = loc("off_chz1"(#loc192)) +#loc499 = loc("off_chz1"(#loc193)) +#loc500 = loc("off_chz1"(#loc194)) +#loc501 = loc("off_chz1"(#loc195)) +#loc502 = loc("Q1"(#loc196)) +#loc503 = loc("DO1"(#loc197)) +#loc504 = loc("LSE1"(#loc198)) +#loc505 = loc("DELTA1"(#loc199)) +#loc506 = loc("sparse_q_idx_offset"(#loc200)) +#loc507 = loc("q_indices"(#loc201)) +#loc508 = loc("q_start"(#loc202)) +#loc509 = loc("q_start"(#loc203)) +#loc510 = loc("sparse_q_num_blocks"(#loc204)) +#loc511 = loc("sparse_q_num_blocks"(#loc205)) +#loc512 = loc("offs_m1"(#loc206)) +#loc513 = loc("offs_m1"(#loc207)) +#loc514 = loc("qT_ptrs"(#loc208)) +#loc515 = loc("qT_ptrs"(#loc210)) +#loc516 = loc("qT_ptrs"(#loc211)) +#loc517 = loc("qT_ptrs"(#loc212)) +#loc518 = loc("qT_ptrs"(#loc213)) +#loc519 = loc("do_ptrs"(#loc214)) +#loc520 = loc("do_ptrs"(#loc215)) +#loc521 = loc("do_ptrs"(#loc216)) +#loc522 = loc("do_ptrs"(#loc217)) +#loc523 = loc("hi"(#loc218)) +#loc524 = loc("hi"(#loc219)) +#loc525 = loc("hi"(#loc220)) +#loc526 = loc("hi"(#loc221)) +#loc527 = loc("dk"(#loc222)) +#loc528 = loc("qT"(#loc223)) +#loc529 = loc(callsite(#loc224 at #loc209)) +#loc530 = loc("lse"(#loc225)) +#loc531 = loc("lse"(#loc226)) +#loc532 = loc("lse"(#loc227)) +#loc533 = loc("lse"(#loc228)) +#loc534 = loc("lse"(#loc229)) +#loc535 = loc("qkT"(#loc230)) +#loc536 = loc("qkT"(#loc231)) +#loc537 = loc("m"(#loc232)) +#loc538 = loc("n"(#loc233)) +#loc539 = loc("post_mod_scores"(#loc234)) +#loc540 = loc("tmp25"(#loc235)) +#loc541 = loc("tmp27"(#loc236)) +#loc542 = loc("tmp28"(#loc237)) +#loc543 = loc("tmp29"(#loc238)) +#loc544 = loc("tmp30"(#loc239)) +#loc545 = loc("tmp31"(#loc240)) +#loc546 = loc("tmp32"(#loc241)) +#loc547 = loc("tmp33"(#loc242)) +#loc548 = loc("tmp36"(#loc243)) +#loc549 = loc("tmp36"(#loc244)) +#loc550 = loc("tmp36"(#loc245)) +#loc551 = loc("tmp36"(#loc246)) +#loc552 = loc("tmp36"(#loc247)) +#loc553 = loc("tmp36"(#loc248)) +#loc554 = loc("tmp38"(#loc249)) +#loc555 = loc("tmp38"(#loc250)) +#loc556 = loc("tmp38"(#loc251)) +#loc557 = loc("tmp38"(#loc252)) +#loc558 = loc("tmp38"(#loc253)) +#loc559 = loc("tmp38"(#loc254)) +#loc560 = loc("tmp39"(#loc255)) +#loc561 = loc("tmp40"(#loc256)) +#loc562 = loc("tmp41"(#loc257)) +#loc563 = loc("tmp42"(#loc258)) +#loc564 = loc("post_mod_scores"(#loc259)) +#loc565 = loc("post_mod_scores"(#loc260)) +#loc566 = loc("pT"(#loc261)) +#loc567 = loc("pT"(#loc262)) +#loc568 = loc("pT"(#loc263)) +#loc569 = loc("do"(#loc265)) +#loc570 = loc("dv"(#loc266)) +#loc571 = loc("dv"(#loc267)) +#loc572 = loc("Di"(#loc268)) +#loc573 = loc("Di"(#loc269)) +#loc574 = loc("dpT"(#loc270)) +#loc575 = loc("dpT"(#loc271)) +#loc576 = loc("dsT"(#loc272)) +#loc577 = loc("dsT"(#loc273)) +#loc578 = loc("dsT"(#loc274)) +#loc579 = loc("grad_scores"(#loc275)) +#loc580 = loc("dsT"(#loc276)) +#loc581 = loc("dk"(#loc277)) +#loc582 = loc("dk"(#loc278)) +#loc583 = loc("dk"(#loc279)) +#loc584 = loc("offset"(#loc280)) +#loc585 = loc("qT_ptrs"(#loc281)) +#loc586 = loc("qT_ptrs"(#loc282)) +#loc587 = loc("do_ptrs"(#loc283)) +#loc588 = loc("do_ptrs"(#loc284)) +#loc589 = loc("offs_m1"(#loc285)) +#loc590 = loc(callsite(#loc286 at #loc209)) +#loc591 = loc("q_indices"(#loc287)) +#loc592 = loc("q_start"(#loc288)) +#loc593 = loc("q_start"(#loc289)) +#loc594 = loc("sparse_q_num_blocks"(#loc290)) +#loc595 = loc("sparse_q_num_blocks"(#loc291)) +#loc596 = loc("offs_m1"(#loc292)) +#loc597 = loc(callsite(#loc224 at #loc293)) +#loc598 = loc(callsite(#loc286 at #loc293)) +#loc599 = loc("dv_ptrs"(#loc295)) +#loc600 = loc("dv_ptrs"(#loc296)) +#loc601 = loc("dk"(#loc300)) +#loc602 = loc(callsite(#loc13 at #loc333)) +#loc603 = loc(callsite(#loc15 at #loc333)) +#loc604 = loc(callsite(#loc13 at #loc334)) +#loc605 = loc(callsite(#loc15 at #loc334)) +#loc606 = loc(callsite(#loc371 at #loc372)) +#loc607 = loc(callsite(#loc373 at #loc372)) +#loc608 = loc(callsite(#loc374 at #loc372)) +#loc609 = loc(callsite(#loc375 at #loc372)) +#loc610 = loc(callsite(#loc376 at #loc372)) +#loc611 = loc(callsite(#loc61 at #loc372)) +#loc612 = loc(callsite(#loc62 at #loc372)) +#loc613 = loc(callsite(#loc373 at #loc377)) +#loc614 = loc(callsite(#loc374 at #loc377)) +#loc615 = loc(callsite(#loc376 at #loc377)) +#loc616 = loc(callsite(#loc62 at #loc377)) +#loc617 = loc(callsite(#loc393 at #loc394)) +#loc618 = loc(callsite(#loc395 at #loc394)) +#loc619 = loc(callsite(#loc396 at #loc394)) +#loc620 = loc(callsite(#loc397 at #loc394)) +#loc621 = loc(callsite(#loc398 at #loc394)) +#loc622 = loc(callsite(#loc399 at #loc394)) +#loc623 = loc(callsite(#loc400 at #loc394)) +#loc624 = loc(callsite(#loc401 at #loc394)) +#loc625 = loc(callsite(#loc402 at #loc394)) +#loc626 = loc(callsite(#loc403 at #loc394)) +#loc627 = loc(callsite(#loc404 at #loc394)) +#loc628 = loc("offs_n2"(#loc405)) +#loc629 = loc(callsite(#loc407 at #loc394)) +#loc630 = loc(callsite(#loc452 at #loc394)) +#loc631 = loc(callsite(#loc469 at #loc394)) +#loc632 = loc(callsite(#loc470 at #loc394)) +#loc633 = loc(callsite(#loc471 at #loc394)) +#loc634 = loc(callsite(#loc472 at #loc394)) +#loc635 = loc(callsite(#loc163 at #loc394)) +#loc636 = loc(callsite(#loc393 at #loc479)) +#loc637 = loc(callsite(#loc395 at #loc479)) +#loc638 = loc(callsite(#loc396 at #loc479)) +#loc639 = loc(callsite(#loc398 at #loc479)) +#loc640 = loc(callsite(#loc399 at #loc479)) +#loc641 = loc(callsite(#loc400 at #loc479)) +#loc642 = loc(callsite(#loc401 at #loc479)) +#loc643 = loc(callsite(#loc404 at #loc479)) +#loc644 = loc(callsite(#loc407 at #loc479)) +#loc645 = loc(callsite(#loc452 at #loc479)) +#loc646 = loc(callsite(#loc469 at #loc479)) +#loc647 = loc(callsite(#loc470 at #loc479)) +#loc648 = loc(callsite(#loc471 at #loc479)) +#loc649 = loc(callsite(#loc472 at #loc479)) +#loc650 = loc(callsite(#loc163 at #loc479)) +#loc651 = loc(callsite(#loc371 at #loc485)) +#loc652 = loc(callsite(#loc373 at #loc485)) +#loc653 = loc(callsite(#loc374 at #loc485)) +#loc654 = loc(callsite(#loc375 at #loc485)) +#loc655 = loc(callsite(#loc376 at #loc485)) +#loc656 = loc(callsite(#loc61 at #loc485)) +#loc657 = loc(callsite(#loc62 at #loc485)) +#loc658 = loc(callsite(#loc374 at #loc486)) +#loc659 = loc(callsite(#loc376 at #loc486)) +#loc660 = loc(callsite(#loc62 at #loc486)) +#loc661 = loc("dk"(#loc487)) +#loc662 = loc(callsite(#loc514 at #loc209)) +#loc663 = loc(callsite(#loc515 at #loc209)) +#loc664 = loc(callsite(#loc516 at #loc209)) +#loc665 = loc(callsite(#loc517 at #loc209)) +#loc666 = loc(callsite(#loc518 at #loc209)) +#loc667 = loc(callsite(#loc519 at #loc209)) +#loc668 = loc(callsite(#loc520 at #loc209)) +#loc669 = loc(callsite(#loc521 at #loc209)) +#loc670 = loc(callsite(#loc522 at #loc209)) +#loc671 = loc(callsite(#loc523 at #loc209)) +#loc672 = loc(callsite(#loc524 at #loc209)) +#loc673 = loc(callsite(#loc525 at #loc209)) +#loc674 = loc(callsite(#loc526 at #loc209)) +#loc675 = loc("dv"(#loc527)) +#loc676 = loc(callsite(#loc528 at #loc529)) +#loc677 = loc(callsite(#loc530 at #loc529)) +#loc678 = loc(callsite(#loc531 at #loc529)) +#loc679 = loc(callsite(#loc532 at #loc529)) +#loc680 = loc(callsite(#loc533 at #loc529)) +#loc681 = loc(callsite(#loc534 at #loc529)) +#loc682 = loc(callsite(#loc535 at #loc529)) +#loc683 = loc(callsite(#loc536 at #loc529)) +#loc684 = loc(callsite(#loc537 at #loc529)) +#loc685 = loc(callsite(#loc538 at #loc529)) +#loc686 = loc(callsite(#loc539 at #loc529)) +#loc687 = loc(callsite(#loc540 at #loc529)) +#loc688 = loc(callsite(#loc541 at #loc529)) +#loc689 = loc(callsite(#loc542 at #loc529)) +#loc690 = loc(callsite(#loc543 at #loc529)) +#loc691 = loc(callsite(#loc544 at #loc529)) +#loc692 = loc(callsite(#loc545 at #loc529)) +#loc693 = loc(callsite(#loc546 at #loc529)) +#loc694 = loc(callsite(#loc547 at #loc529)) +#loc695 = loc(callsite(#loc548 at #loc529)) +#loc696 = loc(callsite(#loc549 at #loc529)) +#loc697 = loc(callsite(#loc550 at #loc529)) +#loc698 = loc(callsite(#loc551 at #loc529)) +#loc699 = loc(callsite(#loc552 at #loc529)) +#loc700 = loc(callsite(#loc553 at #loc529)) +#loc701 = loc(callsite(#loc554 at #loc529)) +#loc702 = loc(callsite(#loc555 at #loc529)) +#loc703 = loc(callsite(#loc556 at #loc529)) +#loc704 = loc(callsite(#loc557 at #loc529)) +#loc705 = loc(callsite(#loc558 at #loc529)) +#loc706 = loc(callsite(#loc559 at #loc529)) +#loc707 = loc(callsite(#loc560 at #loc529)) +#loc708 = loc(callsite(#loc561 at #loc529)) +#loc709 = loc(callsite(#loc562 at #loc529)) +#loc710 = loc(callsite(#loc563 at #loc529)) +#loc711 = loc(callsite(#loc564 at #loc529)) +#loc712 = loc(callsite(#loc565 at #loc529)) +#loc713 = loc(callsite(#loc566 at #loc529)) +#loc714 = loc(callsite(#loc567 at #loc529)) +#loc715 = loc(callsite(#loc568 at #loc529)) +#loc716 = loc(callsite(#loc569 at #loc529)) +#loc717 = loc(callsite(#loc570 at #loc529)) +#loc718 = loc(callsite(#loc571 at #loc529)) +#loc719 = loc(callsite(#loc572 at #loc529)) +#loc720 = loc(callsite(#loc573 at #loc529)) +#loc721 = loc(callsite(#loc574 at #loc529)) +#loc722 = loc(callsite(#loc575 at #loc529)) +#loc723 = loc(callsite(#loc576 at #loc529)) +#loc724 = loc(callsite(#loc577 at #loc529)) +#loc725 = loc(callsite(#loc578 at #loc529)) +#loc726 = loc(callsite(#loc579 at #loc529)) +#loc727 = loc(callsite(#loc580 at #loc529)) +#loc728 = loc(callsite(#loc581 at #loc529)) +#loc729 = loc(callsite(#loc582 at #loc529)) +#loc730 = loc(callsite(#loc583 at #loc529)) +#loc731 = loc(callsite(#loc584 at #loc209)) +#loc732 = loc(callsite(#loc585 at #loc209)) +#loc733 = loc(callsite(#loc586 at #loc209)) +#loc734 = loc(callsite(#loc587 at #loc209)) +#loc735 = loc(callsite(#loc588 at #loc209)) +#loc736 = loc(callsite(#loc589 at #loc209)) +#loc737 = loc(callsite(#loc514 at #loc293)) +#loc738 = loc(callsite(#loc515 at #loc293)) +#loc739 = loc(callsite(#loc516 at #loc293)) +#loc740 = loc(callsite(#loc518 at #loc293)) +#loc741 = loc(callsite(#loc519 at #loc293)) +#loc742 = loc(callsite(#loc520 at #loc293)) +#loc743 = loc(callsite(#loc521 at #loc293)) +#loc744 = loc(callsite(#loc522 at #loc293)) +#loc745 = loc(callsite(#loc523 at #loc293)) +#loc746 = loc(callsite(#loc526 at #loc293)) +#loc747 = loc(callsite(#loc528 at #loc597)) +#loc748 = loc(callsite(#loc530 at #loc597)) +#loc749 = loc(callsite(#loc531 at #loc597)) +#loc750 = loc(callsite(#loc532 at #loc597)) +#loc751 = loc(callsite(#loc533 at #loc597)) +#loc752 = loc(callsite(#loc534 at #loc597)) +#loc753 = loc(callsite(#loc535 at #loc597)) +#loc754 = loc(callsite(#loc536 at #loc597)) +#loc755 = loc(callsite(#loc539 at #loc597)) +#loc756 = loc(callsite(#loc565 at #loc597)) +#loc757 = loc(callsite(#loc566 at #loc597)) +#loc758 = loc(callsite(#loc567 at #loc597)) +#loc759 = loc(callsite(#loc568 at #loc597)) +#loc760 = loc(callsite(#loc569 at #loc597)) +#loc761 = loc(callsite(#loc570 at #loc597)) +#loc762 = loc(callsite(#loc571 at #loc597)) +#loc763 = loc(callsite(#loc572 at #loc597)) +#loc764 = loc(callsite(#loc573 at #loc597)) +#loc765 = loc(callsite(#loc574 at #loc597)) +#loc766 = loc(callsite(#loc575 at #loc597)) +#loc767 = loc(callsite(#loc576 at #loc597)) +#loc768 = loc(callsite(#loc577 at #loc597)) +#loc769 = loc(callsite(#loc578 at #loc597)) +#loc770 = loc(callsite(#loc579 at #loc597)) +#loc771 = loc(callsite(#loc581 at #loc597)) +#loc772 = loc(callsite(#loc582 at #loc597)) +#loc773 = loc(callsite(#loc583 at #loc597)) +#loc774 = loc(callsite(#loc584 at #loc293)) +#loc775 = loc(callsite(#loc585 at #loc293)) +#loc776 = loc(callsite(#loc586 at #loc293)) +#loc777 = loc(callsite(#loc587 at #loc293)) +#loc778 = loc(callsite(#loc588 at #loc293)) +#loc779 = loc(callsite(#loc589 at #loc293)) +#loc780 = loc(callsite(#loc13 at #loc625)) +#loc781 = loc(callsite(#loc15 at #loc625)) +#loc782 = loc("kT_ptrs"(#loc628)) +#loc783 = loc(callsite(#loc406 at #loc629)) +#loc784 = loc(callsite(#loc408 at #loc629)) +#loc785 = loc(callsite(#loc409 at #loc629)) +#loc786 = loc(callsite(#loc410 at #loc629)) +#loc787 = loc(callsite(#loc411 at #loc629)) +#loc788 = loc(callsite(#loc412 at #loc629)) +#loc789 = loc(callsite(#loc413 at #loc629)) +#loc790 = loc(callsite(#loc414 at #loc629)) +#loc791 = loc(callsite(#loc415 at #loc629)) +#loc792 = loc(callsite(#loc416 at #loc629)) +#loc793 = loc(callsite(#loc417 at #loc629)) +#loc794 = loc(callsite(#loc418 at #loc629)) +#loc795 = loc(callsite(#loc419 at #loc629)) +#loc796 = loc(callsite(#loc420 at #loc629)) +#loc797 = loc(callsite(#loc421 at #loc629)) +#loc798 = loc(callsite(#loc422 at #loc629)) +#loc799 = loc(callsite(#loc423 at #loc629)) +#loc800 = loc(callsite(#loc424 at #loc629)) +#loc801 = loc(callsite(#loc425 at #loc629)) +#loc802 = loc(callsite(#loc426 at #loc629)) +#loc803 = loc(callsite(#loc427 at #loc629)) +#loc804 = loc(callsite(#loc428 at #loc629)) +#loc805 = loc(callsite(#loc429 at #loc629)) +#loc806 = loc(callsite(#loc430 at #loc629)) +#loc807 = loc(callsite(#loc431 at #loc629)) +#loc808 = loc(callsite(#loc432 at #loc629)) +#loc809 = loc(callsite(#loc433 at #loc629)) +#loc810 = loc(callsite(#loc434 at #loc629)) +#loc811 = loc(callsite(#loc435 at #loc629)) +#loc812 = loc(callsite(#loc436 at #loc629)) +#loc813 = loc(callsite(#loc437 at #loc629)) +#loc814 = loc(callsite(#loc438 at #loc629)) +#loc815 = loc(callsite(#loc439 at #loc629)) +#loc816 = loc(callsite(#loc440 at #loc629)) +#loc817 = loc(callsite(#loc441 at #loc629)) +#loc818 = loc(callsite(#loc442 at #loc629)) +#loc819 = loc(callsite(#loc443 at #loc629)) +#loc820 = loc(callsite(#loc444 at #loc629)) +#loc821 = loc(callsite(#loc445 at #loc629)) +#loc822 = loc(callsite(#loc446 at #loc629)) +#loc823 = loc(callsite(#loc447 at #loc629)) +#loc824 = loc(callsite(#loc448 at #loc629)) +#loc825 = loc(callsite(#loc449 at #loc629)) +#loc826 = loc(callsite(#loc450 at #loc629)) +#loc827 = loc(callsite(#loc451 at #loc630)) +#loc828 = loc(callsite(#loc453 at #loc630)) +#loc829 = loc(callsite(#loc454 at #loc630)) +#loc830 = loc(callsite(#loc455 at #loc630)) +#loc831 = loc(callsite(#loc456 at #loc630)) +#loc832 = loc(callsite(#loc457 at #loc630)) +#loc833 = loc(callsite(#loc458 at #loc630)) +#loc834 = loc(callsite(#loc459 at #loc630)) +#loc835 = loc(callsite(#loc460 at #loc630)) +#loc836 = loc(callsite(#loc461 at #loc630)) +#loc837 = loc(callsite(#loc462 at #loc630)) +#loc838 = loc(callsite(#loc463 at #loc630)) +#loc839 = loc(callsite(#loc464 at #loc630)) +#loc840 = loc(callsite(#loc465 at #loc630)) +#loc841 = loc(callsite(#loc466 at #loc630)) +#loc842 = loc(callsite(#loc467 at #loc630)) +#loc843 = loc(callsite(#loc468 at #loc630)) +#loc844 = loc(callsite(#loc406 at #loc644)) +#loc845 = loc(callsite(#loc408 at #loc644)) +#loc846 = loc(callsite(#loc409 at #loc644)) +#loc847 = loc(callsite(#loc412 at #loc644)) +#loc848 = loc(callsite(#loc438 at #loc644)) +#loc849 = loc(callsite(#loc439 at #loc644)) +#loc850 = loc(callsite(#loc440 at #loc644)) +#loc851 = loc(callsite(#loc441 at #loc644)) +#loc852 = loc(callsite(#loc442 at #loc644)) +#loc853 = loc(callsite(#loc443 at #loc644)) +#loc854 = loc(callsite(#loc444 at #loc644)) +#loc855 = loc(callsite(#loc445 at #loc644)) +#loc856 = loc(callsite(#loc446 at #loc644)) +#loc857 = loc(callsite(#loc448 at #loc644)) +#loc858 = loc(callsite(#loc449 at #loc644)) +#loc859 = loc(callsite(#loc450 at #loc644)) +#loc860 = loc(callsite(#loc451 at #loc645)) +#loc861 = loc(callsite(#loc453 at #loc645)) +#loc862 = loc(callsite(#loc454 at #loc645)) +#loc863 = loc(callsite(#loc455 at #loc645)) +#loc864 = loc(callsite(#loc456 at #loc645)) +#loc865 = loc(callsite(#loc457 at #loc645)) +#loc866 = loc(callsite(#loc458 at #loc645)) +#loc867 = loc(callsite(#loc459 at #loc645)) +#loc868 = loc(callsite(#loc460 at #loc645)) +#loc869 = loc(callsite(#loc461 at #loc645)) +#loc870 = loc(callsite(#loc462 at #loc645)) +#loc871 = loc(callsite(#loc463 at #loc645)) +#loc872 = loc(callsite(#loc464 at #loc645)) +#loc873 = loc(callsite(#loc465 at #loc645)) +#loc874 = loc(callsite(#loc466 at #loc645)) +#loc875 = loc(callsite(#loc467 at #loc645)) +#loc876 = loc(callsite(#loc468 at #loc645)) +#loc877 = loc(callsite(#loc13 at #loc672)) +#loc878 = loc(callsite(#loc15 at #loc672)) +#loc879 = loc("offs_m1"(#loc675)) +#loc880 = loc(callsite(#loc92 at #loc676)) +#loc881 = loc(callsite(#loc95 at #loc676)) +#loc882 = loc(callsite(#loc96 at #loc676)) +#loc883 = loc(callsite(#loc99 at #loc684)) +#loc884 = loc(callsite(#loc99 at #loc685)) +#loc885 = loc(callsite(#loc264 at #loc716)) +#loc886 = loc(callsite(#loc61 at #loc716)) +#loc887 = loc(callsite(#loc62 at #loc716)) +#loc888 = loc(callsite(#loc451 at #loc731)) +#loc889 = loc(callsite(#loc453 at #loc731)) +#loc890 = loc(callsite(#loc454 at #loc731)) +#loc891 = loc(callsite(#loc455 at #loc731)) +#loc892 = loc(callsite(#loc456 at #loc731)) +#loc893 = loc(callsite(#loc457 at #loc731)) +#loc894 = loc(callsite(#loc458 at #loc731)) +#loc895 = loc(callsite(#loc459 at #loc731)) +#loc896 = loc(callsite(#loc460 at #loc731)) +#loc897 = loc(callsite(#loc461 at #loc731)) +#loc898 = loc(callsite(#loc462 at #loc731)) +#loc899 = loc(callsite(#loc463 at #loc731)) +#loc900 = loc(callsite(#loc464 at #loc731)) +#loc901 = loc(callsite(#loc465 at #loc731)) +#loc902 = loc(callsite(#loc466 at #loc731)) +#loc903 = loc(callsite(#loc467 at #loc731)) +#loc904 = loc(callsite(#loc468 at #loc731)) +#loc905 = loc(callsite(#loc92 at #loc747)) +#loc906 = loc(callsite(#loc95 at #loc747)) +#loc907 = loc(callsite(#loc96 at #loc747)) +#loc908 = loc(callsite(#loc264 at #loc760)) +#loc909 = loc(callsite(#loc61 at #loc760)) +#loc910 = loc(callsite(#loc62 at #loc760)) +#loc911 = loc(callsite(#loc451 at #loc774)) +#loc912 = loc(callsite(#loc453 at #loc774)) +#loc913 = loc(callsite(#loc454 at #loc774)) +#loc914 = loc(callsite(#loc455 at #loc774)) +#loc915 = loc(callsite(#loc456 at #loc774)) +#loc916 = loc(callsite(#loc457 at #loc774)) +#loc917 = loc(callsite(#loc458 at #loc774)) +#loc918 = loc(callsite(#loc459 at #loc774)) +#loc919 = loc(callsite(#loc460 at #loc774)) +#loc920 = loc(callsite(#loc461 at #loc774)) +#loc921 = loc(callsite(#loc462 at #loc774)) +#loc922 = loc(callsite(#loc463 at #loc774)) +#loc923 = loc(callsite(#loc464 at #loc774)) +#loc924 = loc(callsite(#loc465 at #loc774)) +#loc925 = loc(callsite(#loc466 at #loc774)) +#loc926 = loc(callsite(#loc467 at #loc774)) +#loc927 = loc(callsite(#loc468 at #loc774)) +#loc928 = loc("vT_ptrs"(#loc782)) +#loc929 = loc(callsite(#loc92 at #loc783)) +#loc930 = loc(callsite(#loc95 at #loc783)) +#loc931 = loc(callsite(#loc96 at #loc783)) +#loc932 = loc(callsite(#loc99 at #loc786)) +#loc933 = loc(callsite(#loc99 at #loc787)) +#loc934 = loc(callsite(#loc96 at #loc817)) +#loc935 = loc(callsite(#loc92 at #loc844)) +#loc936 = loc(callsite(#loc95 at #loc844)) +#loc937 = loc(callsite(#loc96 at #loc844)) +#loc938 = loc(callsite(#loc96 at #loc851)) +#loc939 = loc("qT_ptrs"(#loc879)) +#loc940 = loc(callsite(#loc928 at #loc394)) +#loc941 = loc(callsite(#loc928 at #loc479)) +#loc942 = loc("do_ptrs"(#loc939)) +#loc943 = loc(callsite(#loc942 at #loc209)) +#loc944 = loc(callsite(#loc942 at #loc293)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.cubin b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.cubin new file mode 100644 index 0000000000000000000000000000000000000000..a218f7f1f34164d2962051a81314d070a5b9a826 Binary files /dev/null and b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.cubin differ diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.json new file mode 100644 index 0000000000000000000000000000000000000000..b921df4565401c77c5af376ec7d5acd1acfa9915 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.json @@ -0,0 +1 @@ +{"hash": "146a76e5887650fc7236bd81c5018c505fdc63f569fbea96947f8886701d58af", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 1, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 32, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_red_fused_mul_0"} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.ttgir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..5f62fda06580d9c5a5a1ca3e723d0c222e52ee03 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/CRVHNZMIOZIPY4RWXWA4KAMMKBP5YY7VNH56VFUUP6EIM4A5LCXQ/triton_red_fused_mul_0.ttgir @@ -0,0 +1,237 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":18:0) +#loc1 = loc(unknown) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":43:25) +#loc56 = loc("in_ptr0"(#loc)) +#loc57 = loc("in_ptr1"(#loc)) +#loc58 = loc("in_ptr2"(#loc)) +#loc59 = loc("out_ptr1"(#loc)) +#loc60 = loc("ks0"(#loc)) +#loc61 = loc("xnumel"(#loc)) +#loc62 = loc("r0_numel"(#loc)) +#loc99 = loc("tmp4"(#loc43)) +#loc118 = loc(callsite(#loc1 at #loc99)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_red_fused_mul_0(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %in_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr1"(#loc)), %in_ptr2: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr2"(#loc)), %out_ptr1: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr1"(#loc)), %ks0: i64 loc("ks0"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc)), %r0_numel: i32 {tt.divisibility = 16 : i32} loc("r0_numel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<0.693147182> : tensor<8x1xf32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<1.44269502> : tensor<8x1xf32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<0> : tensor<8x1xi64, #blocked> loc(#loc1) + %cst_2 = arith.constant dense<0> : tensor<8x1xi32, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<1> : tensor<8x1xi64, #blocked> loc(#loc1) + %cst_4 = arith.constant dense<1> : tensor<8x1xi64, #blocked1> loc(#loc1) + %c8_i32 = arith.constant 8 : i32 loc(#loc1) + %cst_5 = arith.constant dense<0.000000e+00> : tensor<8x128xbf16, #blocked1> loc(#loc1) + %c0_i64 = arith.constant 0 : i64 loc(#loc1) + %c1_i64 = arith.constant 1 : i64 loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<8x128xf32, #blocked1> loc(#loc1) + %cst_7 = arith.constant dense<0> : tensor<8x1xi32, #blocked1> loc(#loc1) + %cst_8 = arith.constant dense<128> : tensor<1x128xi32, #blocked1> loc(#loc1) + %cst_9 = arith.constant dense<128> : tensor<8x1xi64, #blocked1> loc(#loc1) + %cst_10 = arith.constant dense<4096> : tensor<8x1xi64, #blocked1> loc(#loc1) + %cst_11 = arith.constant dense<0> : tensor<8x1xi64, #blocked1> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc63) + %xoffset_12 = arith.muli %xoffset, %c8_i32 : i32 loc(#loc64) + %xindex = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc65) + %xindex_13 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc65) + %xindex_14 = tt.expand_dims %xindex {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1xi32, #blocked1> loc(#loc65) + %xindex_15 = tt.expand_dims %xindex_13 {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xi32, #blocked> loc(#loc65) + %xindex_16 = tt.splat %xoffset_12 : i32 -> tensor<8x1xi32, #blocked1> loc(#loc66) + %xindex_17 = tt.splat %xoffset_12 : i32 -> tensor<8x1xi32, #blocked> loc(#loc66) + %xindex_18 = arith.addi %xindex_16, %xindex_14 : tensor<8x1xi32, #blocked1> loc(#loc66) + %xindex_19 = arith.addi %xindex_17, %xindex_15 : tensor<8x1xi32, #blocked> loc(#loc66) + %xmask = tt.splat %xnumel : i32 -> tensor<8x1xi32, #blocked1> loc(#loc67) + %xmask_20 = tt.splat %xnumel : i32 -> tensor<8x1xi32, #blocked> loc(#loc67) + %xmask_21 = arith.cmpi slt, %xindex_18, %xmask : tensor<8x1xi32, #blocked1> loc(#loc67) + %xmask_22 = arith.cmpi slt, %xindex_19, %xmask_20 : tensor<8x1xi32, #blocked> loc(#loc67) + %r0_base = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc68) + %r0_base_23 = tt.expand_dims %r0_base {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> loc(#loc68) + %x0 = arith.extsi %xindex_18 : tensor<8x1xi32, #blocked1> to tensor<8x1xi64, #blocked1> loc(#loc69) + %x0_24 = arith.extsi %xindex_19 : tensor<8x1xi32, #blocked> to tensor<8x1xi64, #blocked> loc(#loc69) + %x0_25 = tt.splat %ks0 : i64 -> tensor<8x1xi64, #blocked1> loc(#loc69) + %x0_26 = tt.splat %ks0 : i64 -> tensor<8x1xi64, #blocked> loc(#loc69) + %x0_27 = arith.remsi %x0, %x0_25 : tensor<8x1xi64, #blocked1> loc(#loc69) + %x0_28 = arith.remsi %x0_24, %x0_26 : tensor<8x1xi64, #blocked> loc(#loc69) + %quot = arith.divsi %x0, %x0_25 : tensor<8x1xi64, #blocked1> loc(#loc108) + %quot_29 = arith.divsi %x0_24, %x0_26 : tensor<8x1xi64, #blocked> loc(#loc108) + %fixed = arith.cmpi ne, %x0_27, %cst_11 : tensor<8x1xi64, #blocked1> loc(#loc109) + %fixed_30 = arith.cmpi ne, %x0_28, %cst_1 : tensor<8x1xi64, #blocked> loc(#loc109) + %fixed_31 = arith.subi %quot, %cst_4 : tensor<8x1xi64, #blocked1> loc(#loc110) + %fixed_32 = arith.subi %quot_29, %cst_3 : tensor<8x1xi64, #blocked> loc(#loc110) + %fixed_33 = arith.select %fixed, %fixed_31, %quot : tensor<8x1xi1, #blocked1>, tensor<8x1xi64, #blocked1> loc(#loc111) + %fixed_34 = arith.select %fixed_30, %fixed_32, %quot_29 : tensor<8x1xi1, #blocked>, tensor<8x1xi64, #blocked> loc(#loc111) + %x1 = arith.cmpi slt, %xindex_18, %cst_7 : tensor<8x1xi32, #blocked1> loc(#loc112) + %x1_35 = arith.cmpi slt, %xindex_19, %cst_2 : tensor<8x1xi32, #blocked> loc(#loc112) + %x1_36 = arith.cmpi slt, %ks0, %c0_i64 : i64 loc(#loc113) + %x1_37 = tt.splat %x1_36 : i1 -> tensor<8x1xi1, #blocked1> loc(#loc114) + %x1_38 = tt.splat %x1_36 : i1 -> tensor<8x1xi1, #blocked> loc(#loc114) + %x1_39 = arith.cmpi ne, %x1, %x1_37 : tensor<8x1xi1, #blocked1> loc(#loc114) + %x1_40 = arith.cmpi ne, %x1_35, %x1_38 : tensor<8x1xi1, #blocked> loc(#loc114) + %x1_41 = arith.select %x1_39, %fixed_33, %quot : tensor<8x1xi1, #blocked1>, tensor<8x1xi64, #blocked1> loc(#loc115) + %x1_42 = arith.select %x1_40, %fixed_34, %quot_29 : tensor<8x1xi1, #blocked>, tensor<8x1xi64, #blocked> loc(#loc115) + %r0_mask = arith.cmpi slt, %r0_base_23, %cst_8 : tensor<1x128xi32, #blocked1> loc(#loc75) + %tmp0 = arith.muli %x1_41, %cst_9 : tensor<8x1xi64, #blocked1> loc(#loc76) + %tmp0_43 = arith.extsi %r0_base_23 : tensor<1x128xi32, #blocked1> to tensor<1x128xi64, #blocked1> loc(#loc77) + %tmp0_44 = tt.broadcast %tmp0_43 : tensor<1x128xi64, #blocked1> -> tensor<8x128xi64, #blocked1> loc(#loc77) + %tmp0_45 = tt.broadcast %tmp0 : tensor<8x1xi64, #blocked1> -> tensor<8x128xi64, #blocked1> loc(#loc77) + %tmp0_46 = arith.addi %tmp0_44, %tmp0_45 : tensor<8x128xi64, #blocked1> loc(#loc77) + %tmp0_47 = arith.muli %x0_27, %cst_10 : tensor<8x1xi64, #blocked1> loc(#loc78) + %tmp0_48 = tt.broadcast %tmp0_47 : tensor<8x1xi64, #blocked1> -> tensor<8x128xi64, #blocked1> loc(#loc79) + %tmp0_49 = arith.addi %tmp0_46, %tmp0_48 : tensor<8x128xi64, #blocked1> loc(#loc79) + %tmp0_50 = tt.splat %in_ptr0 : !tt.ptr -> tensor<8x128x!tt.ptr, #blocked1> loc(#loc80) + %tmp0_51 = tt.addptr %tmp0_50, %tmp0_49 : tensor<8x128x!tt.ptr, #blocked1>, tensor<8x128xi64, #blocked1> loc(#loc80) + %tmp0_52 = tt.broadcast %r0_mask : tensor<1x128xi1, #blocked1> -> tensor<8x128xi1, #blocked1> loc(#loc81) + %tmp0_53 = tt.broadcast %xmask_21 : tensor<8x1xi1, #blocked1> -> tensor<8x128xi1, #blocked1> loc(#loc81) + %tmp0_54 = arith.andi %tmp0_52, %tmp0_53 : tensor<8x128xi1, #blocked1> loc(#loc81) + %tmp0_55 = tt.load %tmp0_51, %tmp0_54, %cst_5 evictionPolicy = evict_first : tensor<8x128x!tt.ptr, #blocked1> loc(#loc82) + %tmp0_56 = arith.extf %tmp0_55 : tensor<8x128xbf16, #blocked1> to tensor<8x128xf32, #blocked1> loc(#loc83) + %tmp1 = arith.muli %x0_27, %cst_9 : tensor<8x1xi64, #blocked1> loc(#loc84) + %tmp1_57 = tt.broadcast %tmp1 : tensor<8x1xi64, #blocked1> -> tensor<8x128xi64, #blocked1> loc(#loc85) + %tmp1_58 = arith.addi %tmp0_44, %tmp1_57 : tensor<8x128xi64, #blocked1> loc(#loc85) + %tmp1_59 = arith.cmpi sle, %ks0, %c1_i64 : i64 loc(#loc86) + %tmp1_60 = arith.cmpi sgt, %ks0, %c1_i64 : i64 loc(#loc87) + %tmp1_61 = arith.extui %tmp1_60 : i1 to i64 loc(#loc88) + %tmp1_62 = arith.muli %ks0, %tmp1_61 : i64 loc(#loc88) + %tmp1_63 = arith.extui %tmp1_59 : i1 to i64 loc(#loc116) + %tmp1_64 = arith.addi %tmp1_63, %tmp1_62 : i64 loc(#loc89) + %tmp1_65 = tt.splat %tmp1_64 : i64 -> tensor<8x1xi64, #blocked1> loc(#loc91) + %tmp1_66 = tt.splat %tmp1_64 : i64 -> tensor<8x1xi64, #blocked> loc(#loc91) + %tmp1_67 = arith.muli %tmp0, %tmp1_65 : tensor<8x1xi64, #blocked1> loc(#loc91) + %tmp1_68 = tt.broadcast %tmp1_67 : tensor<8x1xi64, #blocked1> -> tensor<8x128xi64, #blocked1> loc(#loc92) + %tmp1_69 = arith.addi %tmp1_58, %tmp1_68 : tensor<8x128xi64, #blocked1> loc(#loc92) + %tmp1_70 = tt.splat %in_ptr1 : !tt.ptr -> tensor<8x128x!tt.ptr, #blocked1> loc(#loc93) + %tmp1_71 = tt.addptr %tmp1_70, %tmp1_69 : tensor<8x128x!tt.ptr, #blocked1>, tensor<8x128xi64, #blocked1> loc(#loc93) + %tmp1_72 = tt.load %tmp1_71, %tmp0_54, %cst_5 evictionPolicy = evict_first : tensor<8x128x!tt.ptr, #blocked1> loc(#loc94) + %tmp1_73 = arith.extf %tmp1_72 : tensor<8x128xbf16, #blocked1> to tensor<8x128xf32, #blocked1> loc(#loc95) + %tmp2 = arith.mulf %tmp0_56, %tmp1_73 : tensor<8x128xf32, #blocked1> loc(#loc96) + %tmp5 = arith.addf %tmp2, %cst_6 : tensor<8x128xf32, #blocked1> loc(#loc97) + %_tmp4 = arith.select %tmp0_54, %tmp5, %cst_6 : tensor<8x128xi1, #blocked1>, tensor<8x128xf32, #blocked1> loc(#loc98) + %tmp4 = "tt.reduce"(%_tmp4) <{axis = 1 : i32}> ({ + ^bb0(%tmp4_80: f32 loc(callsite(#loc1 at #loc99)), %tmp4_81: f32 loc(callsite(#loc1 at #loc99))): + %tmp4_82 = arith.addf %tmp4_80, %tmp4_81 : f32 loc(#loc119) + tt.reduce.return %tmp4_82 : f32 loc(#loc117) + }) : (tensor<8x128xf32, #blocked1>) -> tensor<8xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc117) + %tmp12 = ttg.convert_layout %tmp4 : tensor<8xf32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc100) + %tmp4_74 = tt.expand_dims %tmp12 {axis = 1 : i32} : tensor<8xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<8x1xf32, #blocked> loc(#loc101) + %tmp7 = arith.muli %x1_42, %tmp1_66 : tensor<8x1xi64, #blocked> loc(#loc102) + %tmp7_75 = arith.addi %x0_28, %tmp7 : tensor<8x1xi64, #blocked> loc(#loc103) + %tmp7_76 = tt.splat %in_ptr2 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> loc(#loc104) + %tmp7_77 = tt.addptr %tmp7_76, %tmp7_75 : tensor<8x1x!tt.ptr, #blocked>, tensor<8x1xi64, #blocked> loc(#loc104) + %tmp7_78 = tt.load %tmp7_77, %xmask_22 evictionPolicy = evict_last : tensor<8x1x!tt.ptr, #blocked> loc(#loc105) + %tmp9 = arith.mulf %tmp7_78, %cst : tensor<8x1xf32, #blocked> loc(#loc106) + %tmp11 = arith.mulf %tmp9, %cst_0 : tensor<8x1xf32, #blocked> loc(#loc107) + %tmp12_79 = arith.subf %tmp4_74, %tmp11 : tensor<8x1xf32, #blocked> loc(#loc100) + %0 = tt.splat %out_ptr1 : !tt.ptr -> tensor<8x1x!tt.ptr, #blocked> loc(#loc53) + %1 = tt.addptr %0, %xindex_19 : tensor<8x1x!tt.ptr, #blocked>, tensor<8x1xi32, #blocked> loc(#loc53) + tt.store %1, %tmp12_79, %xmask_22 : tensor<8x1x!tt.ptr, #blocked> loc(#loc54) + tt.return loc(#loc55) + } loc(#loc) +} loc(#loc) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":22:28) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":22:33) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":23:44) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":23:23) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":24:21) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":25:37) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":27:19) +#loc9 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":72:16) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":28:51) +#loc11 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:34) +#loc12 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:44) +#loc13 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":74:47) +#loc14 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:25) +#loc15 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:36) +#loc16 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:32) +#loc17 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/torch/_inductor/runtime/triton_helpers.py":75:47) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":33:29) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:45) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:41) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:55) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:50) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:34) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:70) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:60) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":37:122) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:45) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:41) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:73) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:99) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:90) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:81) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:65) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:58) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:50) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:34) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:106) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":38:168) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":39:22) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":41:23) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":42:48) +#loc42 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":291:36) +#loc44 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":261:15) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":50:19) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":43:28) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":44:39) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":44:35) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":44:30) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":44:87) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":47:18) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":49:19) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":51:25) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":51:37) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/sf/csff2q6dxp7qm2ikuhypvu3j3z3mf5mdwtqmhvido2uebefsj5ye.py":51:4) +#loc63 = loc("xoffset"(#loc2)) +#loc64 = loc("xoffset"(#loc3)) +#loc65 = loc("xindex"(#loc4)) +#loc66 = loc("xindex"(#loc5)) +#loc67 = loc("xmask"(#loc6)) +#loc68 = loc("r0_base"(#loc7)) +#loc69 = loc("x0"(#loc8)) +#loc70 = loc("quot"(#loc9)) +#loc71 = loc("x1"(#loc10)) +#loc72 = loc("fixed"(#loc11)) +#loc73 = loc("fixed"(#loc12)) +#loc74 = loc("fixed"(#loc13)) +#loc75 = loc("r0_mask"(#loc18)) +#loc76 = loc("tmp0"(#loc19)) +#loc77 = loc("tmp0"(#loc20)) +#loc78 = loc("tmp0"(#loc21)) +#loc79 = loc("tmp0"(#loc22)) +#loc80 = loc("tmp0"(#loc23)) +#loc81 = loc("tmp0"(#loc24)) +#loc82 = loc("tmp0"(#loc25)) +#loc83 = loc("tmp0"(#loc26)) +#loc84 = loc("tmp1"(#loc27)) +#loc85 = loc("tmp1"(#loc28)) +#loc86 = loc("tmp1"(#loc29)) +#loc87 = loc("tmp1"(#loc30)) +#loc88 = loc("tmp1"(#loc31)) +#loc89 = loc("tmp1"(#loc32)) +#loc90 = loc("tmp1"(#loc33)) +#loc91 = loc("tmp1"(#loc34)) +#loc92 = loc("tmp1"(#loc35)) +#loc93 = loc("tmp1"(#loc36)) +#loc94 = loc("tmp1"(#loc37)) +#loc95 = loc("tmp1"(#loc38)) +#loc96 = loc("tmp2"(#loc39)) +#loc97 = loc("tmp5"(#loc40)) +#loc98 = loc("_tmp4"(#loc41)) +#loc100 = loc("tmp12"(#loc45)) +#loc101 = loc("tmp4"(#loc46)) +#loc102 = loc("tmp7"(#loc47)) +#loc103 = loc("tmp7"(#loc48)) +#loc104 = loc("tmp7"(#loc49)) +#loc105 = loc("tmp7"(#loc50)) +#loc106 = loc("tmp9"(#loc51)) +#loc107 = loc("tmp11"(#loc52)) +#loc108 = loc(callsite(#loc70 at #loc71)) +#loc109 = loc(callsite(#loc72 at #loc71)) +#loc110 = loc(callsite(#loc73 at #loc71)) +#loc111 = loc(callsite(#loc74 at #loc71)) +#loc112 = loc(callsite(#loc14 at #loc71)) +#loc113 = loc(callsite(#loc15 at #loc71)) +#loc114 = loc(callsite(#loc16 at #loc71)) +#loc115 = loc(callsite(#loc17 at #loc71)) +#loc116 = loc(fused[#loc89, #loc90]) +#loc117 = loc(callsite(#loc42 at #loc99)) +#loc119 = loc(callsite(#loc44 at #loc117)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/__grp__triton_tem_fused_mul_1.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/__grp__triton_tem_fused_mul_1.json new file mode 100644 index 0000000000000000000000000000000000000000..b8de708df3044f8d64b185b415eac1cf37806868 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/__grp__triton_tem_fused_mul_1.json @@ -0,0 +1 @@ +{"child_paths": {"triton_tem_fused_mul_1.source": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.source", "triton_tem_fused_mul_1.ttir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttir", "triton_tem_fused_mul_1.ttgir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttgir", "triton_tem_fused_mul_1.llir": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.llir", "triton_tem_fused_mul_1.ptx": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ptx", "triton_tem_fused_mul_1.cubin": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.cubin", "triton_tem_fused_mul_1.json": "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.json"}} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.json b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.json new file mode 100644 index 0000000000000000000000000000000000000000..0feefd82b7495f911c0fec8a51577428ff377a4e --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.json @@ -0,0 +1 @@ +{"hash": "7a52496e04adcbf63c92845458707deaf98fe12f99af2ab0d113695b8f6c65cd", "target": {"backend": "cuda", "arch": 90, "warp_size": 32}, "num_warps": 8, "num_ctas": 1, "num_stages": 3, "warp_size": 32, "maxnreg": null, "cluster_dims": [1, 1, 1], "ptx_version": null, "ptx_options": null, "ir_override": null, "enable_fp_fusion": true, "launch_cooperative_grid": false, "launch_pdl": false, "supported_fp8_dtypes": ["fp8e4b15", "fp8e4nv", "fp8e5"], "deprecated_fp8_dot_operand_dtypes": ["fp8e4b15"], "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "max_num_imprecise_acc_default": 1073741824, "extern_libs": [["libdevice", "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/backends/nvidia/lib/libdevice.10.bc"]], "debug": true, "backend_name": "cuda", "sanitize_overflow": false, "arch": "sm90", "instrumentation_mode": "", "triton_version": "3.5.1", "tensordesc_meta": [], "shared": 164864, "tmem_size": 0, "global_scratch_size": 0, "global_scratch_align": 1, "profile_scratch_size": 0, "profile_scratch_align": 1, "name": "triton_tem_fused_mul_1"} \ No newline at end of file diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.llir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.llir new file mode 100644 index 0000000000000000000000000000000000000000..4df3797aa3a0f08a60825d404fc049dc510b42f4 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.llir @@ -0,0 +1,14122 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" +target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +@global_smem = external addrspace(3) global [0 x i8], align 16 +@.str = private unnamed_addr constant [11 x i8] c"__CUDA_FTZ\00", align 1 + +; Function Attrs: nounwind +define ptx_kernel void @triton_tem_fused_mul_1(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, ptr addrspace(1) %3, ptr addrspace(1) %4, ptr addrspace(1) %5, ptr addrspace(1) %6, ptr addrspace(1) %7, ptr addrspace(1) %8, ptr addrspace(1) %9, ptr addrspace(1) %10, ptr addrspace(1) %11, ptr addrspace(1) %12, ptr addrspace(1) %13, ptr addrspace(1) %14, ptr addrspace(1) %15, ptr addrspace(1) %16, i32 %17, i32 %18, ptr addrspace(1) readnone captures(none) %19, ptr addrspace(1) readnone captures(none) %20) local_unnamed_addr #0 !dbg !5 { + %22 = shl i32 %17, 12, !dbg !8 + %23 = icmp slt i32 %17, 2, !dbg !9 + %24 = zext i1 %23 to i32, !dbg !10 + %25 = icmp sgt i32 %17, 1, !dbg !11 + %26 = select i1 %25, i32 %17, i32 0, !dbg !12 + %27 = add i32 %26, %24, !dbg !13 + %28 = shl i32 %27, 12, !dbg !14 + %29 = shl i32 %27, 7, !dbg !15 + %30 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !16 + %31 = add i32 %18, 127, !dbg !17 + %32 = sdiv i32 %31, 128, !dbg !21 + %33 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !22 + %34 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !dbg !23 + %35 = shl nuw nsw i32 %34, 7, !dbg !24 + %36 = zext nneg i32 %35 to i64, !dbg !25 + %37 = shl nuw nsw i32 %33, 10, !dbg !26 + %38 = mul i32 %37, %18, !dbg !27 + %39 = add i32 %38, %35, !dbg !28 + %40 = sext i32 %39 to i64, !dbg !29 + %41 = getelementptr bfloat, ptr addrspace(1) %1, i64 %36, !dbg !30 + %42 = getelementptr bfloat, ptr addrspace(1) %2, i64 %36, !dbg !31 + %43 = getelementptr bfloat, ptr addrspace(1) %7, i64 %40, !dbg !32 + %44 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !33 + %45 = lshr i32 %44, 5, !dbg !33 + %46 = and i32 %44, 240, !dbg !33 + %47 = lshr exact i32 %46, 4, !dbg !33 + %48 = or disjoint i32 %47, 16, !dbg !33 + %49 = or disjoint i32 %47, 32, !dbg !33 + %50 = or disjoint i32 %47, 48, !dbg !33 + %51 = or disjoint i32 %47, 64, !dbg !33 + %52 = or disjoint i32 %47, 80, !dbg !33 + %53 = or disjoint i32 %47, 96, !dbg !33 + %54 = or disjoint i32 %47, 112, !dbg !33 + %55 = lshr i32 %44, 1, !dbg !33 + %56 = and i32 %55, 112, !dbg !33 + %57 = lshr i32 %44, 2, !dbg !33 + %58 = and i32 %57, 7, !dbg !33 + %59 = or disjoint i32 %56, %58, !dbg !33 + %60 = or disjoint i32 %59, 8, !dbg !33 + %.not = icmp slt i32 %30, %32, !dbg !34 + br i1 %.not, label %4732, label %61, !dbg !35 + +61: ; preds = %21 + %62 = add i32 %17, 127, !dbg !36 + %63 = sdiv i32 %62, 128, !dbg !38 + %64 = sub i32 %30, %32, !dbg !39 + %.frozen = freeze i32 %64, !dbg !40 + %.frozen4126 = freeze i32 %63, !dbg !40 + %65 = sdiv i32 %.frozen, %.frozen4126, !dbg !40 + %66 = shl nuw nsw i32 %34, 2, !dbg !41 + %67 = add i32 %65, %66, !dbg !42 + %68 = mul i32 %65, %.frozen4126, !dbg !43 + %.decomposed = sub i32 %.frozen, %68, !dbg !43 + %69 = shl i32 %67, 7, !dbg !44 + %70 = mul i32 %22, %33, !dbg !45 + %71 = add i32 %69, %70, !dbg !46 + %72 = sext i32 %71 to i64, !dbg !47 + %73 = mul i32 %67, %29, !dbg !48 + %74 = mul i32 %28, %33, !dbg !49 + %75 = add i32 %73, %74, !dbg !50 + %76 = sext i32 %75 to i64, !dbg !51 + %77 = shl nuw nsw i32 %33, 5, !dbg !52 + %78 = add i32 %67, %77, !dbg !53 + %79 = mul i32 %78, %17, !dbg !54 + %80 = sext i32 %79 to i64, !dbg !55 + %81 = getelementptr bfloat, ptr addrspace(1) %0, i64 %72, !dbg !56 + %82 = getelementptr bfloat, ptr addrspace(1) %5, i64 %76, !dbg !57 + %83 = getelementptr bfloat, ptr addrspace(1) %6, i64 %72, !dbg !58 + %84 = getelementptr float, ptr addrspace(1) %3, i64 %80, !dbg !59 + %85 = getelementptr float, ptr addrspace(1) %4, i64 %80, !dbg !60 + %86 = shl nsw i32 %.decomposed, 7, !dbg !61 + %87 = or disjoint i32 %86, %47, !dbg !62 + %88 = or disjoint i32 %86, %48, !dbg !62 + %89 = or disjoint i32 %86, %49, !dbg !62 + %90 = or disjoint i32 %86, %50, !dbg !62 + %91 = or disjoint i32 %86, %51, !dbg !62 + %92 = or disjoint i32 %86, %52, !dbg !62 + %93 = or disjoint i32 %86, %53, !dbg !62 + %94 = or disjoint i32 %86, %54, !dbg !62 + %95 = or disjoint i32 %86, %59, !dbg !62 + %96 = or disjoint i32 %86, %60, !dbg !62 + %97 = shl i32 %87, 12, !dbg !63 + %98 = shl i32 %88, 12, !dbg !63 + %99 = shl i32 %89, 12, !dbg !63 + %100 = shl i32 %90, 12, !dbg !63 + %101 = shl i32 %91, 12, !dbg !63 + %102 = shl i32 %92, 12, !dbg !63 + %103 = shl i32 %93, 12, !dbg !63 + %104 = shl i32 %94, 12, !dbg !63 + %105 = sext i32 %97 to i64, !dbg !66 + %106 = getelementptr bfloat, ptr addrspace(1) %81, i64 %105, !dbg !66 + %107 = sext i32 %98 to i64, !dbg !66 + %108 = getelementptr bfloat, ptr addrspace(1) %81, i64 %107, !dbg !66 + %109 = sext i32 %99 to i64, !dbg !66 + %110 = getelementptr bfloat, ptr addrspace(1) %81, i64 %109, !dbg !66 + %111 = sext i32 %100 to i64, !dbg !66 + %112 = getelementptr bfloat, ptr addrspace(1) %81, i64 %111, !dbg !66 + %113 = sext i32 %101 to i64, !dbg !66 + %114 = getelementptr bfloat, ptr addrspace(1) %81, i64 %113, !dbg !66 + %115 = sext i32 %102 to i64, !dbg !66 + %116 = getelementptr bfloat, ptr addrspace(1) %81, i64 %115, !dbg !66 + %117 = sext i32 %103 to i64, !dbg !66 + %118 = getelementptr bfloat, ptr addrspace(1) %81, i64 %117, !dbg !66 + %119 = sext i32 %104 to i64, !dbg !66 + %120 = getelementptr bfloat, ptr addrspace(1) %81, i64 %119, !dbg !66 + %121 = shl nuw nsw i32 %44, 3, !dbg !67 + %122 = and i32 %121, 120, !dbg !67 + %123 = zext nneg i32 %122 to i64, !dbg !68 + %124 = getelementptr bfloat, ptr addrspace(1) %106, i64 %123, !dbg !68 + %125 = getelementptr bfloat, ptr addrspace(1) %108, i64 %123, !dbg !68 + %126 = getelementptr bfloat, ptr addrspace(1) %110, i64 %123, !dbg !68 + %127 = getelementptr bfloat, ptr addrspace(1) %112, i64 %123, !dbg !68 + %128 = getelementptr bfloat, ptr addrspace(1) %114, i64 %123, !dbg !68 + %129 = getelementptr bfloat, ptr addrspace(1) %116, i64 %123, !dbg !68 + %130 = getelementptr bfloat, ptr addrspace(1) %118, i64 %123, !dbg !68 + %131 = getelementptr bfloat, ptr addrspace(1) %120, i64 %123, !dbg !68 + %132 = icmp slt i32 %87, %17, !dbg !69 + %133 = icmp slt i32 %88, %17, !dbg !69 + %134 = icmp slt i32 %89, %17, !dbg !69 + %135 = icmp slt i32 %90, %17, !dbg !69 + %136 = icmp slt i32 %91, %17, !dbg !69 + %137 = icmp slt i32 %92, %17, !dbg !69 + %138 = icmp slt i32 %93, %17, !dbg !69 + %139 = icmp slt i32 %94, %17, !dbg !69 + %140 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %124, i1 %132) #3, !dbg !70 + %141 = extractvalue { i32, i32, i32, i32 } %140, 0, !dbg !70 + %142 = extractvalue { i32, i32, i32, i32 } %140, 1, !dbg !70 + %143 = extractvalue { i32, i32, i32, i32 } %140, 2, !dbg !70 + %144 = extractvalue { i32, i32, i32, i32 } %140, 3, !dbg !70 + %145 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %125, i1 %133) #3, !dbg !70 + %146 = extractvalue { i32, i32, i32, i32 } %145, 0, !dbg !70 + %147 = extractvalue { i32, i32, i32, i32 } %145, 1, !dbg !70 + %148 = extractvalue { i32, i32, i32, i32 } %145, 2, !dbg !70 + %149 = extractvalue { i32, i32, i32, i32 } %145, 3, !dbg !70 + %150 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %126, i1 %134) #3, !dbg !70 + %151 = extractvalue { i32, i32, i32, i32 } %150, 0, !dbg !70 + %152 = extractvalue { i32, i32, i32, i32 } %150, 1, !dbg !70 + %153 = extractvalue { i32, i32, i32, i32 } %150, 2, !dbg !70 + %154 = extractvalue { i32, i32, i32, i32 } %150, 3, !dbg !70 + %155 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %127, i1 %135) #3, !dbg !70 + %156 = extractvalue { i32, i32, i32, i32 } %155, 0, !dbg !70 + %157 = extractvalue { i32, i32, i32, i32 } %155, 1, !dbg !70 + %158 = extractvalue { i32, i32, i32, i32 } %155, 2, !dbg !70 + %159 = extractvalue { i32, i32, i32, i32 } %155, 3, !dbg !70 + %160 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %128, i1 %136) #3, !dbg !70 + %161 = extractvalue { i32, i32, i32, i32 } %160, 0, !dbg !70 + %162 = extractvalue { i32, i32, i32, i32 } %160, 1, !dbg !70 + %163 = extractvalue { i32, i32, i32, i32 } %160, 2, !dbg !70 + %164 = extractvalue { i32, i32, i32, i32 } %160, 3, !dbg !70 + %165 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %129, i1 %137) #3, !dbg !70 + %166 = extractvalue { i32, i32, i32, i32 } %165, 0, !dbg !70 + %167 = extractvalue { i32, i32, i32, i32 } %165, 1, !dbg !70 + %168 = extractvalue { i32, i32, i32, i32 } %165, 2, !dbg !70 + %169 = extractvalue { i32, i32, i32, i32 } %165, 3, !dbg !70 + %170 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %130, i1 %138) #3, !dbg !70 + %171 = extractvalue { i32, i32, i32, i32 } %170, 0, !dbg !70 + %172 = extractvalue { i32, i32, i32, i32 } %170, 1, !dbg !70 + %173 = extractvalue { i32, i32, i32, i32 } %170, 2, !dbg !70 + %174 = extractvalue { i32, i32, i32, i32 } %170, 3, !dbg !70 + %175 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %131, i1 %139) #3, !dbg !70 + %176 = extractvalue { i32, i32, i32, i32 } %175, 0, !dbg !70 + %177 = extractvalue { i32, i32, i32, i32 } %175, 1, !dbg !70 + %178 = extractvalue { i32, i32, i32, i32 } %175, 2, !dbg !70 + %179 = extractvalue { i32, i32, i32, i32 } %175, 3, !dbg !70 + %180 = shl nuw nsw i32 %44, 4, !dbg !70 + %181 = and i32 %180, 112, !dbg !70 + %182 = shl nuw nsw i32 %46, 3, !dbg !70 + %183 = and i32 %44, 112, !dbg !70 + %184 = and i32 %44, 8, !dbg !70 + %185 = shl nuw nsw i32 %184, 11, !dbg !70 + %186 = or disjoint i32 %181, %182, !dbg !70 + %187 = xor i32 %186, %183, !dbg !70 + %188 = or disjoint i32 %187, %185, !dbg !70 + %189 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %188, !dbg !70 + %190 = insertelement <4 x i32> poison, i32 %141, i64 0, !dbg !70 + %191 = insertelement <4 x i32> %190, i32 %142, i64 1, !dbg !70 + %192 = insertelement <4 x i32> %191, i32 %143, i64 2, !dbg !70 + %193 = insertelement <4 x i32> %192, i32 %144, i64 3, !dbg !70 + store <4 x i32> %193, ptr addrspace(3) %189, align 16, !dbg !70 + %194 = or disjoint i32 %188, 2048, !dbg !70 + %195 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %194, !dbg !70 + %196 = insertelement <4 x i32> poison, i32 %146, i64 0, !dbg !70 + %197 = insertelement <4 x i32> %196, i32 %147, i64 1, !dbg !70 + %198 = insertelement <4 x i32> %197, i32 %148, i64 2, !dbg !70 + %199 = insertelement <4 x i32> %198, i32 %149, i64 3, !dbg !70 + store <4 x i32> %199, ptr addrspace(3) %195, align 16, !dbg !70 + %200 = or disjoint i32 %188, 4096, !dbg !70 + %201 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %200, !dbg !70 + %202 = insertelement <4 x i32> poison, i32 %151, i64 0, !dbg !70 + %203 = insertelement <4 x i32> %202, i32 %152, i64 1, !dbg !70 + %204 = insertelement <4 x i32> %203, i32 %153, i64 2, !dbg !70 + %205 = insertelement <4 x i32> %204, i32 %154, i64 3, !dbg !70 + store <4 x i32> %205, ptr addrspace(3) %201, align 16, !dbg !70 + %206 = or disjoint i32 %188, 6144, !dbg !70 + %207 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %206, !dbg !70 + %208 = insertelement <4 x i32> poison, i32 %156, i64 0, !dbg !70 + %209 = insertelement <4 x i32> %208, i32 %157, i64 1, !dbg !70 + %210 = insertelement <4 x i32> %209, i32 %158, i64 2, !dbg !70 + %211 = insertelement <4 x i32> %210, i32 %159, i64 3, !dbg !70 + store <4 x i32> %211, ptr addrspace(3) %207, align 16, !dbg !70 + %212 = or disjoint i32 %188, 8192, !dbg !70 + %213 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %212, !dbg !70 + %214 = insertelement <4 x i32> poison, i32 %161, i64 0, !dbg !70 + %215 = insertelement <4 x i32> %214, i32 %162, i64 1, !dbg !70 + %216 = insertelement <4 x i32> %215, i32 %163, i64 2, !dbg !70 + %217 = insertelement <4 x i32> %216, i32 %164, i64 3, !dbg !70 + store <4 x i32> %217, ptr addrspace(3) %213, align 16, !dbg !70 + %218 = or disjoint i32 %188, 10240, !dbg !70 + %219 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %218, !dbg !70 + %220 = insertelement <4 x i32> poison, i32 %166, i64 0, !dbg !70 + %221 = insertelement <4 x i32> %220, i32 %167, i64 1, !dbg !70 + %222 = insertelement <4 x i32> %221, i32 %168, i64 2, !dbg !70 + %223 = insertelement <4 x i32> %222, i32 %169, i64 3, !dbg !70 + store <4 x i32> %223, ptr addrspace(3) %219, align 16, !dbg !70 + %224 = or disjoint i32 %188, 12288, !dbg !70 + %225 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %224, !dbg !70 + %226 = insertelement <4 x i32> poison, i32 %171, i64 0, !dbg !70 + %227 = insertelement <4 x i32> %226, i32 %172, i64 1, !dbg !70 + %228 = insertelement <4 x i32> %227, i32 %173, i64 2, !dbg !70 + %229 = insertelement <4 x i32> %228, i32 %174, i64 3, !dbg !70 + store <4 x i32> %229, ptr addrspace(3) %225, align 16, !dbg !70 + %230 = or disjoint i32 %188, 14336, !dbg !70 + %231 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %230, !dbg !70 + %232 = insertelement <4 x i32> poison, i32 %176, i64 0, !dbg !70 + %233 = insertelement <4 x i32> %232, i32 %177, i64 1, !dbg !70 + %234 = insertelement <4 x i32> %233, i32 %178, i64 2, !dbg !70 + %235 = insertelement <4 x i32> %234, i32 %179, i64 3, !dbg !70 + store <4 x i32> %235, ptr addrspace(3) %231, align 16, !dbg !70 + %236 = shl i32 %87, 7, !dbg !71 + %237 = shl i32 %88, 7, !dbg !71 + %238 = shl i32 %89, 7, !dbg !71 + %239 = shl i32 %90, 7, !dbg !71 + %240 = shl i32 %91, 7, !dbg !71 + %241 = shl i32 %92, 7, !dbg !71 + %242 = shl i32 %93, 7, !dbg !71 + %243 = shl i32 %94, 7, !dbg !71 + %244 = sext i32 %236 to i64, !dbg !73 + %245 = getelementptr bfloat, ptr addrspace(1) %82, i64 %244, !dbg !73 + %246 = sext i32 %237 to i64, !dbg !73 + %247 = getelementptr bfloat, ptr addrspace(1) %82, i64 %246, !dbg !73 + %248 = sext i32 %238 to i64, !dbg !73 + %249 = getelementptr bfloat, ptr addrspace(1) %82, i64 %248, !dbg !73 + %250 = sext i32 %239 to i64, !dbg !73 + %251 = getelementptr bfloat, ptr addrspace(1) %82, i64 %250, !dbg !73 + %252 = sext i32 %240 to i64, !dbg !73 + %253 = getelementptr bfloat, ptr addrspace(1) %82, i64 %252, !dbg !73 + %254 = sext i32 %241 to i64, !dbg !73 + %255 = getelementptr bfloat, ptr addrspace(1) %82, i64 %254, !dbg !73 + %256 = sext i32 %242 to i64, !dbg !73 + %257 = getelementptr bfloat, ptr addrspace(1) %82, i64 %256, !dbg !73 + %258 = sext i32 %243 to i64, !dbg !73 + %259 = getelementptr bfloat, ptr addrspace(1) %82, i64 %258, !dbg !73 + %260 = getelementptr bfloat, ptr addrspace(1) %245, i64 %123, !dbg !74 + %261 = getelementptr bfloat, ptr addrspace(1) %247, i64 %123, !dbg !74 + %262 = getelementptr bfloat, ptr addrspace(1) %249, i64 %123, !dbg !74 + %263 = getelementptr bfloat, ptr addrspace(1) %251, i64 %123, !dbg !74 + %264 = getelementptr bfloat, ptr addrspace(1) %253, i64 %123, !dbg !74 + %265 = getelementptr bfloat, ptr addrspace(1) %255, i64 %123, !dbg !74 + %266 = getelementptr bfloat, ptr addrspace(1) %257, i64 %123, !dbg !74 + %267 = getelementptr bfloat, ptr addrspace(1) %259, i64 %123, !dbg !74 + %268 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %260, i1 %132) #3, !dbg !75 + %269 = extractvalue { i32, i32, i32, i32 } %268, 0, !dbg !75 + %270 = extractvalue { i32, i32, i32, i32 } %268, 1, !dbg !75 + %271 = extractvalue { i32, i32, i32, i32 } %268, 2, !dbg !75 + %272 = extractvalue { i32, i32, i32, i32 } %268, 3, !dbg !75 + %273 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %261, i1 %133) #3, !dbg !75 + %274 = extractvalue { i32, i32, i32, i32 } %273, 0, !dbg !75 + %275 = extractvalue { i32, i32, i32, i32 } %273, 1, !dbg !75 + %276 = extractvalue { i32, i32, i32, i32 } %273, 2, !dbg !75 + %277 = extractvalue { i32, i32, i32, i32 } %273, 3, !dbg !75 + %278 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %262, i1 %134) #3, !dbg !75 + %279 = extractvalue { i32, i32, i32, i32 } %278, 0, !dbg !75 + %280 = extractvalue { i32, i32, i32, i32 } %278, 1, !dbg !75 + %281 = extractvalue { i32, i32, i32, i32 } %278, 2, !dbg !75 + %282 = extractvalue { i32, i32, i32, i32 } %278, 3, !dbg !75 + %283 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %263, i1 %135) #3, !dbg !75 + %284 = extractvalue { i32, i32, i32, i32 } %283, 0, !dbg !75 + %285 = extractvalue { i32, i32, i32, i32 } %283, 1, !dbg !75 + %286 = extractvalue { i32, i32, i32, i32 } %283, 2, !dbg !75 + %287 = extractvalue { i32, i32, i32, i32 } %283, 3, !dbg !75 + %288 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %264, i1 %136) #3, !dbg !75 + %289 = extractvalue { i32, i32, i32, i32 } %288, 0, !dbg !75 + %290 = extractvalue { i32, i32, i32, i32 } %288, 1, !dbg !75 + %291 = extractvalue { i32, i32, i32, i32 } %288, 2, !dbg !75 + %292 = extractvalue { i32, i32, i32, i32 } %288, 3, !dbg !75 + %293 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %265, i1 %137) #3, !dbg !75 + %294 = extractvalue { i32, i32, i32, i32 } %293, 0, !dbg !75 + %295 = extractvalue { i32, i32, i32, i32 } %293, 1, !dbg !75 + %296 = extractvalue { i32, i32, i32, i32 } %293, 2, !dbg !75 + %297 = extractvalue { i32, i32, i32, i32 } %293, 3, !dbg !75 + %298 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %266, i1 %138) #3, !dbg !75 + %299 = extractvalue { i32, i32, i32, i32 } %298, 0, !dbg !75 + %300 = extractvalue { i32, i32, i32, i32 } %298, 1, !dbg !75 + %301 = extractvalue { i32, i32, i32, i32 } %298, 2, !dbg !75 + %302 = extractvalue { i32, i32, i32, i32 } %298, 3, !dbg !75 + %303 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %267, i1 %139) #3, !dbg !75 + %304 = extractvalue { i32, i32, i32, i32 } %303, 0, !dbg !75 + %305 = extractvalue { i32, i32, i32, i32 } %303, 1, !dbg !75 + %306 = extractvalue { i32, i32, i32, i32 } %303, 2, !dbg !75 + %307 = extractvalue { i32, i32, i32, i32 } %303, 3, !dbg !75 + %308 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %188, !dbg !75 + %309 = insertelement <4 x i32> poison, i32 %269, i64 0, !dbg !75 + %310 = insertelement <4 x i32> %309, i32 %270, i64 1, !dbg !75 + %311 = insertelement <4 x i32> %310, i32 %271, i64 2, !dbg !75 + %312 = insertelement <4 x i32> %311, i32 %272, i64 3, !dbg !75 + store <4 x i32> %312, ptr addrspace(3) %308, align 16, !dbg !75 + %313 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %194, !dbg !75 + %314 = insertelement <4 x i32> poison, i32 %274, i64 0, !dbg !75 + %315 = insertelement <4 x i32> %314, i32 %275, i64 1, !dbg !75 + %316 = insertelement <4 x i32> %315, i32 %276, i64 2, !dbg !75 + %317 = insertelement <4 x i32> %316, i32 %277, i64 3, !dbg !75 + store <4 x i32> %317, ptr addrspace(3) %313, align 16, !dbg !75 + %318 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %200, !dbg !75 + %319 = insertelement <4 x i32> poison, i32 %279, i64 0, !dbg !75 + %320 = insertelement <4 x i32> %319, i32 %280, i64 1, !dbg !75 + %321 = insertelement <4 x i32> %320, i32 %281, i64 2, !dbg !75 + %322 = insertelement <4 x i32> %321, i32 %282, i64 3, !dbg !75 + store <4 x i32> %322, ptr addrspace(3) %318, align 16, !dbg !75 + %323 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %206, !dbg !75 + %324 = insertelement <4 x i32> poison, i32 %284, i64 0, !dbg !75 + %325 = insertelement <4 x i32> %324, i32 %285, i64 1, !dbg !75 + %326 = insertelement <4 x i32> %325, i32 %286, i64 2, !dbg !75 + %327 = insertelement <4 x i32> %326, i32 %287, i64 3, !dbg !75 + store <4 x i32> %327, ptr addrspace(3) %323, align 16, !dbg !75 + %328 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %212, !dbg !75 + %329 = insertelement <4 x i32> poison, i32 %289, i64 0, !dbg !75 + %330 = insertelement <4 x i32> %329, i32 %290, i64 1, !dbg !75 + %331 = insertelement <4 x i32> %330, i32 %291, i64 2, !dbg !75 + %332 = insertelement <4 x i32> %331, i32 %292, i64 3, !dbg !75 + store <4 x i32> %332, ptr addrspace(3) %328, align 16, !dbg !75 + %333 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %218, !dbg !75 + %334 = insertelement <4 x i32> poison, i32 %294, i64 0, !dbg !75 + %335 = insertelement <4 x i32> %334, i32 %295, i64 1, !dbg !75 + %336 = insertelement <4 x i32> %335, i32 %296, i64 2, !dbg !75 + %337 = insertelement <4 x i32> %336, i32 %297, i64 3, !dbg !75 + store <4 x i32> %337, ptr addrspace(3) %333, align 16, !dbg !75 + %338 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %224, !dbg !75 + %339 = insertelement <4 x i32> poison, i32 %299, i64 0, !dbg !75 + %340 = insertelement <4 x i32> %339, i32 %300, i64 1, !dbg !75 + %341 = insertelement <4 x i32> %340, i32 %301, i64 2, !dbg !75 + %342 = insertelement <4 x i32> %341, i32 %302, i64 3, !dbg !75 + store <4 x i32> %342, ptr addrspace(3) %338, align 16, !dbg !75 + %343 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 %230, !dbg !75 + %344 = insertelement <4 x i32> poison, i32 %304, i64 0, !dbg !75 + %345 = insertelement <4 x i32> %344, i32 %305, i64 1, !dbg !75 + %346 = insertelement <4 x i32> %345, i32 %306, i64 2, !dbg !75 + %347 = insertelement <4 x i32> %346, i32 %307, i64 3, !dbg !75 + store <4 x i32> %347, ptr addrspace(3) %343, align 16, !dbg !75 + %348 = icmp slt i32 %95, %17, !dbg !76 + %349 = icmp slt i32 %96, %17, !dbg !76 + %350 = sext i32 %95 to i64, !dbg !77 + %351 = getelementptr float, ptr addrspace(1) %85, i64 %350, !dbg !77 + %352 = sext i32 %96 to i64, !dbg !77 + %353 = getelementptr float, ptr addrspace(1) %85, i64 %352, !dbg !77 + %354 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %351, i1 %348) #3, !dbg !78 + %355 = bitcast i32 %354 to float, !dbg !78 + %356 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %353, i1 %349) #3, !dbg !78 + %357 = bitcast i32 %356 to float, !dbg !78 + %358 = getelementptr float, ptr addrspace(1) %84, i64 %350, !dbg !79 + %359 = getelementptr float, ptr addrspace(1) %84, i64 %352, !dbg !79 + %360 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %358, i1 %348) #3, !dbg !80 + %361 = bitcast i32 %360 to float, !dbg !80 + %362 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %359, i1 %349) #3, !dbg !80 + %363 = bitcast i32 %362 to float, !dbg !80 + %364 = fcmp oeq float %361, 0xFFF0000000000000, !dbg !81 + %365 = fcmp oeq float %363, 0xFFF0000000000000, !dbg !81 + %366 = select i1 %364, float 0.000000e+00, float %361, !dbg !82 + %367 = select i1 %365, float 0.000000e+00, float %363, !dbg !82 + %368 = sext i32 %.decomposed to i64, !dbg !83 + %369 = getelementptr i32, ptr addrspace(1) %9, i64 %368, !dbg !83 + %370 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %369) #3, !dbg !84 + %371 = shl i32 %370, 7, !dbg !85 + %372 = getelementptr i32, ptr addrspace(1) %8, i64 %368, !dbg !86 + %373 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %372) #3, !dbg !87 + %374 = and i32 %44, 3, !dbg !88 + %375 = shl nuw nsw i32 %374, 1, !dbg !88 + %376 = or disjoint i32 %375, 1, !dbg !88 + %377 = insertelement <2 x i32> poison, i32 %375, i64 0, !dbg !88 + %378 = shufflevector <2 x i32> %377, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !88 + %379 = or disjoint <2 x i32> %378, , !dbg !88 + %380 = insertelement <4 x i32> poison, i32 %375, i64 0, !dbg !88 + %381 = shufflevector <4 x i32> %380, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !88 + %382 = or disjoint <4 x i32> %381, , !dbg !88 + %383 = insertelement <8 x i32> poison, i32 %375, i64 0, !dbg !88 + %384 = shufflevector <8 x i32> %383, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !88 + %385 = or disjoint <8 x i32> %384, , !dbg !88 + %386 = or disjoint i32 %371, %47, !dbg !89 + %387 = or disjoint i32 %371, %48, !dbg !89 + %388 = or disjoint i32 %371, %49, !dbg !89 + %389 = or disjoint i32 %371, %50, !dbg !89 + %390 = shl i32 %386, 10, !dbg !90 + %391 = shl i32 %387, 10, !dbg !90 + %392 = shl i32 %388, 10, !dbg !90 + %393 = shl i32 %389, 10, !dbg !90 + %394 = sext i32 %390 to i64, !dbg !92 + %395 = getelementptr bfloat, ptr addrspace(1) %41, i64 %394, !dbg !92 + %396 = sext i32 %391 to i64, !dbg !92 + %397 = getelementptr bfloat, ptr addrspace(1) %41, i64 %396, !dbg !92 + %398 = sext i32 %392 to i64, !dbg !92 + %399 = getelementptr bfloat, ptr addrspace(1) %41, i64 %398, !dbg !92 + %400 = sext i32 %393 to i64, !dbg !92 + %401 = getelementptr bfloat, ptr addrspace(1) %41, i64 %400, !dbg !92 + %402 = getelementptr bfloat, ptr addrspace(1) %395, i64 %123, !dbg !93 + %403 = getelementptr bfloat, ptr addrspace(1) %397, i64 %123, !dbg !93 + %404 = getelementptr bfloat, ptr addrspace(1) %399, i64 %123, !dbg !93 + %405 = getelementptr bfloat, ptr addrspace(1) %401, i64 %123, !dbg !93 + %406 = getelementptr bfloat, ptr addrspace(1) %42, i64 %394, !dbg !94 + %407 = getelementptr bfloat, ptr addrspace(1) %42, i64 %396, !dbg !94 + %408 = getelementptr bfloat, ptr addrspace(1) %42, i64 %398, !dbg !94 + %409 = getelementptr bfloat, ptr addrspace(1) %42, i64 %400, !dbg !94 + %410 = getelementptr bfloat, ptr addrspace(1) %406, i64 %123, !dbg !95 + %411 = getelementptr bfloat, ptr addrspace(1) %407, i64 %123, !dbg !95 + %412 = getelementptr bfloat, ptr addrspace(1) %408, i64 %123, !dbg !95 + %413 = getelementptr bfloat, ptr addrspace(1) %409, i64 %123, !dbg !95 + %414 = shl i32 %373, 1, !dbg !96 + %415 = add i32 %18, 63, !dbg !97 + %416 = sdiv i32 %415, 64, !dbg !98 + %417 = tail call i32 @llvm.smax.i32(i32 %416, i32 1), !dbg !99 + %418 = tail call i32 @llvm.smin.i32(i32 %414, i32 %417), !dbg !100 + %419 = icmp sgt i32 %414, 0, !dbg !101 + %420 = icmp slt i32 %386, %18, !dbg !102 + %421 = icmp slt i32 %387, %18, !dbg !102 + %422 = icmp slt i32 %388, %18, !dbg !102 + %423 = icmp slt i32 %389, %18, !dbg !102 + %424 = and i1 %419, %420, !dbg !101 + %425 = and i1 %419, %421, !dbg !101 + %426 = and i1 %419, %422, !dbg !101 + %427 = and i1 %419, %423, !dbg !101 + %428 = shl nuw nsw i32 %184, 10, !dbg !103 + %429 = or disjoint i32 %187, %428, !dbg !103 + %430 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %429, !dbg !103 + %431 = select i1 %424, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %430, ptr addrspace(1) %402, i32 %431) #3, !dbg !103 + %432 = or disjoint i32 %429, 2048, !dbg !103 + %433 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %432, !dbg !103 + %434 = select i1 %425, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %433, ptr addrspace(1) %403, i32 %434) #3, !dbg !103 + %435 = or disjoint i32 %429, 4096, !dbg !103 + %436 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %435, !dbg !103 + %437 = select i1 %426, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %436, ptr addrspace(1) %404, i32 %437) #3, !dbg !103 + %438 = or disjoint i32 %429, 6144, !dbg !103 + %439 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %438, !dbg !103 + %440 = select i1 %427, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %439, ptr addrspace(1) %405, i32 %440) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + %441 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %429, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %441, ptr addrspace(1) %410, i32 %431) #3, !dbg !103 + %442 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %432, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %442, ptr addrspace(1) %411, i32 %434) #3, !dbg !103 + %443 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %435, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %443, ptr addrspace(1) %412, i32 %437) #3, !dbg !103 + %444 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %438, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %444, ptr addrspace(1) %413, i32 %440) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + %445 = icmp sgt i32 %418, 1, !dbg !101 + %446 = getelementptr i8, ptr addrspace(1) %402, i64 131072, !dbg !104 + %447 = getelementptr i8, ptr addrspace(1) %403, i64 131072, !dbg !104 + %448 = getelementptr i8, ptr addrspace(1) %404, i64 131072, !dbg !104 + %449 = getelementptr i8, ptr addrspace(1) %405, i64 131072, !dbg !104 + %450 = getelementptr i8, ptr addrspace(1) %410, i64 131072, !dbg !105 + %451 = getelementptr i8, ptr addrspace(1) %411, i64 131072, !dbg !105 + %452 = getelementptr i8, ptr addrspace(1) %412, i64 131072, !dbg !105 + %453 = getelementptr i8, ptr addrspace(1) %413, i64 131072, !dbg !105 + %454 = or disjoint i32 %386, 64, !dbg !106 + %455 = or disjoint i32 %387, 64, !dbg !106 + %456 = or disjoint i32 %388, 64, !dbg !106 + %457 = or disjoint i32 %389, 64, !dbg !106 + %458 = icmp slt i32 %454, %18, !dbg !102 + %459 = icmp slt i32 %455, %18, !dbg !102 + %460 = icmp slt i32 %456, %18, !dbg !102 + %461 = icmp slt i32 %457, %18, !dbg !102 + %462 = and i1 %445, %458, !dbg !101 + %463 = and i1 %445, %459, !dbg !101 + %464 = and i1 %445, %460, !dbg !101 + %465 = and i1 %445, %461, !dbg !101 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !103 + %466 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %429, !dbg !103 + %467 = select i1 %462, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %466, ptr addrspace(1) %446, i32 %467) #3, !dbg !103 + %468 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %432, !dbg !103 + %469 = select i1 %463, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %468, ptr addrspace(1) %447, i32 %469) #3, !dbg !103 + %470 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %435, !dbg !103 + %471 = select i1 %464, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %470, ptr addrspace(1) %448, i32 %471) #3, !dbg !103 + %472 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %438, !dbg !103 + %473 = select i1 %465, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %472, ptr addrspace(1) %449, i32 %473) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + %474 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %429, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %474, ptr addrspace(1) %450, i32 %467) #3, !dbg !103 + %475 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %432, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %475, ptr addrspace(1) %451, i32 %469) #3, !dbg !103 + %476 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %435, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %476, ptr addrspace(1) %452, i32 %471) #3, !dbg !103 + %477 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %438, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %477, ptr addrspace(1) %453, i32 %473) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !107 + br i1 %419, label %.lr.ph, label %._crit_edge, !dbg !101 + +.lr.ph: ; preds = %61 + %478 = srem i32 %96, %17, !dbg !108 + %479 = sdiv i32 %478, 16, !dbg !109 + %480 = icmp slt i32 %478, 0, !dbg !110 + %481 = and i32 %478, 15, !dbg !111 + %.not844 = icmp ne i32 %481, 0, !dbg !111 + %narrow846 = and i1 %480, %.not844, !dbg !112 + %482 = sext i1 %narrow846 to i32, !dbg !112 + %483 = add nsw i32 %479, %482, !dbg !112 + %484 = srem i32 %95, %17, !dbg !108 + %485 = sdiv i32 %484, 16, !dbg !109 + %486 = icmp slt i32 %484, 0, !dbg !110 + %487 = and i32 %484, 15, !dbg !111 + %.not843 = icmp ne i32 %487, 0, !dbg !111 + %narrow845 = and i1 %486, %.not843, !dbg !112 + %488 = sext i1 %narrow845 to i32, !dbg !112 + %489 = add nsw i32 %485, %488, !dbg !112 + %490 = icmp sgt i32 %478, -1, !dbg !113 + %491 = icmp sgt i32 %484, -1, !dbg !113 + %492 = insertelement <2 x i32> poison, i32 %371, i64 0, !dbg !89 + %493 = shufflevector <2 x i32> %492, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !89 + %494 = shufflevector <8 x i32> %385, <8 x i32> poison, <2 x i32> , !dbg !89 + %495 = or disjoint <2 x i32> %493, %494, !dbg !89 + %496 = shufflevector <8 x i32> %385, <8 x i32> poison, <2 x i32> , !dbg !89 + %497 = or disjoint <2 x i32> %493, %496, !dbg !89 + %498 = shufflevector <8 x i32> %385, <8 x i32> poison, <2 x i32> , !dbg !89 + %499 = or disjoint <2 x i32> %493, %498, !dbg !89 + %500 = shufflevector <8 x i32> %385, <8 x i32> poison, <2 x i32> , !dbg !89 + %501 = or disjoint <2 x i32> %493, %500, !dbg !89 + %502 = shufflevector <4 x i32> %382, <4 x i32> poison, <2 x i32> , !dbg !89 + %503 = or disjoint <2 x i32> %493, %502, !dbg !89 + %504 = shufflevector <4 x i32> %382, <4 x i32> poison, <2 x i32> , !dbg !89 + %505 = or disjoint <2 x i32> %493, %504, !dbg !89 + %506 = or disjoint <2 x i32> %493, %379, !dbg !89 + %507 = insertelement <2 x i32> %377, i32 %376, i64 1, !dbg !89 + %508 = or disjoint <2 x i32> %493, %507, !dbg !89 + %509 = add nsw i32 %418, -2 + %510 = add nsw i32 %418, -1 + %smax = tail call i32 @llvm.smax.i32(i32 %418, i32 1), !dbg !101 + %511 = insertelement <2 x i1> poison, i1 %480, i64 0, !dbg !114 + %512 = shufflevector <2 x i1> %511, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !114 + %513 = insertelement <2 x i32> poison, i32 %478, i64 0, !dbg !115 + %514 = shufflevector <2 x i32> %513, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !115 + %515 = insertelement <2 x i1> poison, i1 %490, i64 0, !dbg !116 + %516 = shufflevector <2 x i1> %515, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !116 + %517 = insertelement <2 x i32> poison, i32 %483, i64 0, !dbg !117 + %518 = shufflevector <2 x i32> %517, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !117 + %519 = insertelement <2 x i32> poison, i32 %18, i64 0, !dbg !102 + %520 = shufflevector <2 x i32> %519, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !102 + %521 = insertelement <2 x float> poison, float %357, i64 0, !dbg !118 + %522 = shufflevector <2 x float> %521, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !118 + %523 = insertelement <2 x i1> poison, i1 %486, i64 0, !dbg !114 + %524 = shufflevector <2 x i1> %523, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !114 + %525 = insertelement <2 x i32> poison, i32 %484, i64 0, !dbg !115 + %526 = shufflevector <2 x i32> %525, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !115 + %527 = insertelement <2 x i1> poison, i1 %491, i64 0, !dbg !116 + %528 = shufflevector <2 x i1> %527, <2 x i1> poison, <2 x i32> zeroinitializer, !dbg !116 + %529 = insertelement <2 x i32> poison, i32 %489, i64 0, !dbg !117 + %530 = shufflevector <2 x i32> %529, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !117 + %531 = insertelement <2 x float> poison, float %355, i64 0, !dbg !118 + %532 = shufflevector <2 x float> %531, <2 x float> poison, <2 x i32> zeroinitializer, !dbg !118 + br label %533, !dbg !101 + +533: ; preds = %.lr.ph, %__nv_exp2f.exit1609 + %534 = phi i32 [ 64, %.lr.ph ], [ %2382, %__nv_exp2f.exit1609 ] + %535 = phi i32 [ -1, %.lr.ph ], [ %614, %__nv_exp2f.exit1609 ] + %536 = phi i32 [ 1, %.lr.ph ], [ %2399, %__nv_exp2f.exit1609 ] + %.pn9331626 = phi ptr addrspace(1) [ %453, %.lr.ph ], [ %2392, %__nv_exp2f.exit1609 ] + %.pn9491625 = phi ptr addrspace(1) [ %452, %.lr.ph ], [ %2391, %__nv_exp2f.exit1609 ] + %.pn9651624 = phi ptr addrspace(1) [ %451, %.lr.ph ], [ %2390, %__nv_exp2f.exit1609 ] + %.pn9811623 = phi ptr addrspace(1) [ %450, %.lr.ph ], [ %2389, %__nv_exp2f.exit1609 ] + %.pn9111622 = phi i32 [ %457, %.lr.ph ], [ %2396, %__nv_exp2f.exit1609 ] + %.pn9131621 = phi i32 [ %456, %.lr.ph ], [ %2395, %__nv_exp2f.exit1609 ] + %.pn9151620 = phi i32 [ %455, %.lr.ph ], [ %2394, %__nv_exp2f.exit1609 ] + %.pn9171619 = phi i32 [ %454, %.lr.ph ], [ %2393, %__nv_exp2f.exit1609 ] + %.pn8611618 = phi ptr addrspace(1) [ %449, %.lr.ph ], [ %2388, %__nv_exp2f.exit1609 ] + %.pn8771617 = phi ptr addrspace(1) [ %448, %.lr.ph ], [ %2387, %__nv_exp2f.exit1609 ] + %.pn8931616 = phi ptr addrspace(1) [ %447, %.lr.ph ], [ %2386, %__nv_exp2f.exit1609 ] + %.pn9091615 = phi ptr addrspace(1) [ %446, %.lr.ph ], [ %2385, %__nv_exp2f.exit1609 ] + %537 = phi float [ 0.000000e+00, %.lr.ph ], [ %2289, %__nv_exp2f.exit1609 ] + %538 = phi float [ 0.000000e+00, %.lr.ph ], [ %2290, %__nv_exp2f.exit1609 ] + %539 = phi float [ 0.000000e+00, %.lr.ph ], [ %2291, %__nv_exp2f.exit1609 ] + %540 = phi float [ 0.000000e+00, %.lr.ph ], [ %2292, %__nv_exp2f.exit1609 ] + %541 = phi float [ 0.000000e+00, %.lr.ph ], [ %2293, %__nv_exp2f.exit1609 ] + %542 = phi float [ 0.000000e+00, %.lr.ph ], [ %2294, %__nv_exp2f.exit1609 ] + %543 = phi float [ 0.000000e+00, %.lr.ph ], [ %2295, %__nv_exp2f.exit1609 ] + %544 = phi float [ 0.000000e+00, %.lr.ph ], [ %2296, %__nv_exp2f.exit1609 ] + %545 = phi float [ 0.000000e+00, %.lr.ph ], [ %2297, %__nv_exp2f.exit1609 ] + %546 = phi float [ 0.000000e+00, %.lr.ph ], [ %2298, %__nv_exp2f.exit1609 ] + %547 = phi float [ 0.000000e+00, %.lr.ph ], [ %2299, %__nv_exp2f.exit1609 ] + %548 = phi float [ 0.000000e+00, %.lr.ph ], [ %2300, %__nv_exp2f.exit1609 ] + %549 = phi float [ 0.000000e+00, %.lr.ph ], [ %2301, %__nv_exp2f.exit1609 ] + %550 = phi float [ 0.000000e+00, %.lr.ph ], [ %2302, %__nv_exp2f.exit1609 ] + %551 = phi float [ 0.000000e+00, %.lr.ph ], [ %2303, %__nv_exp2f.exit1609 ] + %552 = phi float [ 0.000000e+00, %.lr.ph ], [ %2304, %__nv_exp2f.exit1609 ] + %553 = phi float [ 0.000000e+00, %.lr.ph ], [ %2305, %__nv_exp2f.exit1609 ] + %554 = phi float [ 0.000000e+00, %.lr.ph ], [ %2306, %__nv_exp2f.exit1609 ] + %555 = phi float [ 0.000000e+00, %.lr.ph ], [ %2307, %__nv_exp2f.exit1609 ] + %556 = phi float [ 0.000000e+00, %.lr.ph ], [ %2308, %__nv_exp2f.exit1609 ] + %557 = phi float [ 0.000000e+00, %.lr.ph ], [ %2309, %__nv_exp2f.exit1609 ] + %558 = phi float [ 0.000000e+00, %.lr.ph ], [ %2310, %__nv_exp2f.exit1609 ] + %559 = phi float [ 0.000000e+00, %.lr.ph ], [ %2311, %__nv_exp2f.exit1609 ] + %560 = phi float [ 0.000000e+00, %.lr.ph ], [ %2312, %__nv_exp2f.exit1609 ] + %561 = phi float [ 0.000000e+00, %.lr.ph ], [ %2313, %__nv_exp2f.exit1609 ] + %562 = phi float [ 0.000000e+00, %.lr.ph ], [ %2314, %__nv_exp2f.exit1609 ] + %563 = phi float [ 0.000000e+00, %.lr.ph ], [ %2315, %__nv_exp2f.exit1609 ] + %564 = phi float [ 0.000000e+00, %.lr.ph ], [ %2316, %__nv_exp2f.exit1609 ] + %565 = phi float [ 0.000000e+00, %.lr.ph ], [ %2317, %__nv_exp2f.exit1609 ] + %566 = phi float [ 0.000000e+00, %.lr.ph ], [ %2318, %__nv_exp2f.exit1609 ] + %567 = phi float [ 0.000000e+00, %.lr.ph ], [ %2319, %__nv_exp2f.exit1609 ] + %568 = phi float [ 0.000000e+00, %.lr.ph ], [ %2320, %__nv_exp2f.exit1609 ] + %569 = phi float [ 0.000000e+00, %.lr.ph ], [ %2321, %__nv_exp2f.exit1609 ] + %570 = phi float [ 0.000000e+00, %.lr.ph ], [ %2322, %__nv_exp2f.exit1609 ] + %571 = phi float [ 0.000000e+00, %.lr.ph ], [ %2323, %__nv_exp2f.exit1609 ] + %572 = phi float [ 0.000000e+00, %.lr.ph ], [ %2324, %__nv_exp2f.exit1609 ] + %573 = phi float [ 0.000000e+00, %.lr.ph ], [ %2325, %__nv_exp2f.exit1609 ] + %574 = phi float [ 0.000000e+00, %.lr.ph ], [ %2326, %__nv_exp2f.exit1609 ] + %575 = phi float [ 0.000000e+00, %.lr.ph ], [ %2327, %__nv_exp2f.exit1609 ] + %576 = phi float [ 0.000000e+00, %.lr.ph ], [ %2328, %__nv_exp2f.exit1609 ] + %577 = phi float [ 0.000000e+00, %.lr.ph ], [ %2329, %__nv_exp2f.exit1609 ] + %578 = phi float [ 0.000000e+00, %.lr.ph ], [ %2330, %__nv_exp2f.exit1609 ] + %579 = phi float [ 0.000000e+00, %.lr.ph ], [ %2331, %__nv_exp2f.exit1609 ] + %580 = phi float [ 0.000000e+00, %.lr.ph ], [ %2332, %__nv_exp2f.exit1609 ] + %581 = phi float [ 0.000000e+00, %.lr.ph ], [ %2333, %__nv_exp2f.exit1609 ] + %582 = phi float [ 0.000000e+00, %.lr.ph ], [ %2334, %__nv_exp2f.exit1609 ] + %583 = phi float [ 0.000000e+00, %.lr.ph ], [ %2335, %__nv_exp2f.exit1609 ] + %584 = phi float [ 0.000000e+00, %.lr.ph ], [ %2336, %__nv_exp2f.exit1609 ] + %585 = phi float [ 0.000000e+00, %.lr.ph ], [ %2337, %__nv_exp2f.exit1609 ] + %586 = phi float [ 0.000000e+00, %.lr.ph ], [ %2338, %__nv_exp2f.exit1609 ] + %587 = phi float [ 0.000000e+00, %.lr.ph ], [ %2339, %__nv_exp2f.exit1609 ] + %588 = phi float [ 0.000000e+00, %.lr.ph ], [ %2340, %__nv_exp2f.exit1609 ] + %589 = phi float [ 0.000000e+00, %.lr.ph ], [ %2341, %__nv_exp2f.exit1609 ] + %590 = phi float [ 0.000000e+00, %.lr.ph ], [ %2342, %__nv_exp2f.exit1609 ] + %591 = phi float [ 0.000000e+00, %.lr.ph ], [ %2343, %__nv_exp2f.exit1609 ] + %592 = phi float [ 0.000000e+00, %.lr.ph ], [ %2344, %__nv_exp2f.exit1609 ] + %593 = phi float [ 0.000000e+00, %.lr.ph ], [ %2345, %__nv_exp2f.exit1609 ] + %594 = phi float [ 0.000000e+00, %.lr.ph ], [ %2346, %__nv_exp2f.exit1609 ] + %595 = phi float [ 0.000000e+00, %.lr.ph ], [ %2347, %__nv_exp2f.exit1609 ] + %596 = phi float [ 0.000000e+00, %.lr.ph ], [ %2348, %__nv_exp2f.exit1609 ] + %597 = phi float [ 0.000000e+00, %.lr.ph ], [ %2349, %__nv_exp2f.exit1609 ] + %598 = phi float [ 0.000000e+00, %.lr.ph ], [ %2350, %__nv_exp2f.exit1609 ] + %599 = phi float [ 0.000000e+00, %.lr.ph ], [ %2351, %__nv_exp2f.exit1609 ] + %600 = phi float [ 0.000000e+00, %.lr.ph ], [ %2352, %__nv_exp2f.exit1609 ] + %601 = phi i32 [ 0, %.lr.ph ], [ %2363, %__nv_exp2f.exit1609 ] + %602 = phi <2 x i32> [ %495, %.lr.ph ], [ %2362, %__nv_exp2f.exit1609 ] + %603 = phi <2 x i32> [ %497, %.lr.ph ], [ %2361, %__nv_exp2f.exit1609 ] + %604 = phi <2 x i32> [ %499, %.lr.ph ], [ %2360, %__nv_exp2f.exit1609 ] + %605 = phi <2 x i32> [ %501, %.lr.ph ], [ %2359, %__nv_exp2f.exit1609 ] + %606 = phi <2 x i32> [ %503, %.lr.ph ], [ %2358, %__nv_exp2f.exit1609 ] + %607 = phi <2 x i32> [ %505, %.lr.ph ], [ %2357, %__nv_exp2f.exit1609 ] + %608 = phi <2 x i32> [ %506, %.lr.ph ], [ %2356, %__nv_exp2f.exit1609 ] + %609 = phi <2 x i32> [ %508, %.lr.ph ], [ %2355, %__nv_exp2f.exit1609 ] + %610 = icmp slt i32 %601, %509, !dbg !101 + %611 = icmp slt i32 %601, %510, !dbg !101 + %612 = add i32 %535, 1, !dbg !101 + %613 = icmp sgt i32 %612, 2, !dbg !101 + %614 = select i1 %613, i32 0, i32 %612, !dbg !101 + %615 = icmp slt <2 x i32> %609, %520, !dbg !102 + %616 = icmp slt <2 x i32> %608, %520, !dbg !102 + %617 = icmp slt <2 x i32> %607, %520, !dbg !102 + %618 = icmp slt <2 x i32> %606, %520, !dbg !102 + %619 = icmp slt <2 x i32> %605, %520, !dbg !102 + %620 = icmp slt <2 x i32> %604, %520, !dbg !102 + %621 = icmp slt <2 x i32> %603, %520, !dbg !102 + %622 = icmp slt <2 x i32> %602, %520, !dbg !102 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !103 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !103 + %623 = shl i32 %614, 13, !dbg !103 + %624 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %623, !dbg !103 + %625 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %45, i32 0, i32 31), !dbg !107 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !107 + %626 = shl i32 %625, 11, !dbg !107 + %627 = and i32 %626, 8192, !dbg !107 + %628 = add i32 %627, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %629 = lshr exact i32 %628, 4, !dbg !107 + %630 = and i32 %629, 16383, !dbg !107 + %631 = zext nneg i32 %630 to i64, !dbg !107 + %632 = or disjoint i64 %631, 4611686293372403712, !dbg !107 + %633 = ptrtoint ptr addrspace(3) %624 to i32, !dbg !107 + %634 = lshr exact i32 %633, 4, !dbg !107 + %635 = and i32 %634, 16383, !dbg !107 + %636 = zext nneg i32 %635 to i64, !dbg !107 + %637 = or disjoint i64 %636, 4611686293338849280, !dbg !107 + %638 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %632, i64 %637) #3, !dbg !107 + %639 = or disjoint i32 %627, 32, !dbg !107 + %640 = add i32 %639, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %641 = lshr exact i32 %640, 4, !dbg !107 + %642 = and i32 %641, 16383, !dbg !107 + %643 = zext nneg i32 %642 to i64, !dbg !107 + %644 = or disjoint i64 %643, 4611686293372403712, !dbg !107 + %645 = add i32 %633, 32, !dbg !107 + %646 = lshr exact i32 %645, 4, !dbg !107 + %647 = and i32 %646, 16383, !dbg !107 + %648 = zext nneg i32 %647 to i64, !dbg !107 + %649 = or disjoint i64 %648, 4611686293338849280, !dbg !107 + %650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 0, !dbg !107 + %651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 1, !dbg !107 + %652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 2, !dbg !107 + %653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 3, !dbg !107 + %654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 4, !dbg !107 + %655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 5, !dbg !107 + %656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 6, !dbg !107 + %657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 7, !dbg !107 + %658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 8, !dbg !107 + %659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 9, !dbg !107 + %660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 10, !dbg !107 + %661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 11, !dbg !107 + %662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 12, !dbg !107 + %663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 13, !dbg !107 + %664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 14, !dbg !107 + %665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 15, !dbg !107 + %666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 16, !dbg !107 + %667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 17, !dbg !107 + %668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 18, !dbg !107 + %669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 19, !dbg !107 + %670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 20, !dbg !107 + %671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 21, !dbg !107 + %672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 22, !dbg !107 + %673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 23, !dbg !107 + %674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 24, !dbg !107 + %675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 25, !dbg !107 + %676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 26, !dbg !107 + %677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 27, !dbg !107 + %678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 28, !dbg !107 + %679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 29, !dbg !107 + %680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 30, !dbg !107 + %681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %638, 31, !dbg !107 + %682 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %650, float %651, float %652, float %653, float %654, float %655, float %656, float %657, float %658, float %659, float %660, float %661, float %662, float %663, float %664, float %665, float %666, float %667, float %668, float %669, float %670, float %671, float %672, float %673, float %674, float %675, float %676, float %677, float %678, float %679, float %680, float %681, i64 %644, i64 %649, i1 true) #3, !dbg !107 + %683 = or disjoint i32 %627, 64, !dbg !107 + %684 = add i32 %683, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %685 = lshr exact i32 %684, 4, !dbg !107 + %686 = and i32 %685, 16383, !dbg !107 + %687 = zext nneg i32 %686 to i64, !dbg !107 + %688 = or disjoint i64 %687, 4611686293372403712, !dbg !107 + %689 = add i32 %633, 64, !dbg !107 + %690 = lshr exact i32 %689, 4, !dbg !107 + %691 = and i32 %690, 16383, !dbg !107 + %692 = zext nneg i32 %691 to i64, !dbg !107 + %693 = or disjoint i64 %692, 4611686293338849280, !dbg !107 + %694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 0, !dbg !107 + %695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 1, !dbg !107 + %696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 2, !dbg !107 + %697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 3, !dbg !107 + %698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 4, !dbg !107 + %699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 5, !dbg !107 + %700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 6, !dbg !107 + %701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 7, !dbg !107 + %702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 8, !dbg !107 + %703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 9, !dbg !107 + %704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 10, !dbg !107 + %705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 11, !dbg !107 + %706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 12, !dbg !107 + %707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 13, !dbg !107 + %708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 14, !dbg !107 + %709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 15, !dbg !107 + %710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 16, !dbg !107 + %711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 17, !dbg !107 + %712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 18, !dbg !107 + %713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 19, !dbg !107 + %714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 20, !dbg !107 + %715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 21, !dbg !107 + %716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 22, !dbg !107 + %717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 23, !dbg !107 + %718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 24, !dbg !107 + %719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 25, !dbg !107 + %720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 26, !dbg !107 + %721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 27, !dbg !107 + %722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 28, !dbg !107 + %723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 29, !dbg !107 + %724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 30, !dbg !107 + %725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %682, 31, !dbg !107 + %726 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %694, float %695, float %696, float %697, float %698, float %699, float %700, float %701, float %702, float %703, float %704, float %705, float %706, float %707, float %708, float %709, float %710, float %711, float %712, float %713, float %714, float %715, float %716, float %717, float %718, float %719, float %720, float %721, float %722, float %723, float %724, float %725, i64 %688, i64 %693, i1 true) #3, !dbg !107 + %727 = or disjoint i32 %627, 96, !dbg !107 + %728 = add i32 %727, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %729 = lshr exact i32 %728, 4, !dbg !107 + %730 = and i32 %729, 16383, !dbg !107 + %731 = zext nneg i32 %730 to i64, !dbg !107 + %732 = or disjoint i64 %731, 4611686293372403712, !dbg !107 + %733 = add i32 %633, 96, !dbg !107 + %734 = lshr exact i32 %733, 4, !dbg !107 + %735 = and i32 %734, 16383, !dbg !107 + %736 = zext nneg i32 %735 to i64, !dbg !107 + %737 = or disjoint i64 %736, 4611686293338849280, !dbg !107 + %738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 0, !dbg !107 + %739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 1, !dbg !107 + %740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 2, !dbg !107 + %741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 3, !dbg !107 + %742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 4, !dbg !107 + %743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 5, !dbg !107 + %744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 6, !dbg !107 + %745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 7, !dbg !107 + %746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 8, !dbg !107 + %747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 9, !dbg !107 + %748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 10, !dbg !107 + %749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 11, !dbg !107 + %750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 12, !dbg !107 + %751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 13, !dbg !107 + %752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 14, !dbg !107 + %753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 15, !dbg !107 + %754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 16, !dbg !107 + %755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 17, !dbg !107 + %756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 18, !dbg !107 + %757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 19, !dbg !107 + %758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 20, !dbg !107 + %759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 21, !dbg !107 + %760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 22, !dbg !107 + %761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 23, !dbg !107 + %762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 24, !dbg !107 + %763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 25, !dbg !107 + %764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 26, !dbg !107 + %765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 27, !dbg !107 + %766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 28, !dbg !107 + %767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 29, !dbg !107 + %768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 30, !dbg !107 + %769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %726, 31, !dbg !107 + %770 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %738, float %739, float %740, float %741, float %742, float %743, float %744, float %745, float %746, float %747, float %748, float %749, float %750, float %751, float %752, float %753, float %754, float %755, float %756, float %757, float %758, float %759, float %760, float %761, float %762, float %763, float %764, float %765, float %766, float %767, float %768, float %769, i64 %732, i64 %737, i1 true) #3, !dbg !107 + %771 = or disjoint i32 %627, 16384, !dbg !107 + %772 = add i32 %771, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %773 = lshr exact i32 %772, 4, !dbg !107 + %774 = and i32 %773, 16383, !dbg !107 + %775 = zext nneg i32 %774 to i64, !dbg !107 + %776 = or disjoint i64 %775, 4611686293372403712, !dbg !107 + %777 = add i32 %633, 8192, !dbg !107 + %778 = lshr exact i32 %777, 4, !dbg !107 + %779 = and i32 %778, 16383, !dbg !107 + %780 = zext nneg i32 %779 to i64, !dbg !107 + %781 = or disjoint i64 %780, 4611686293338849280, !dbg !107 + %782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 0, !dbg !107 + %783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 1, !dbg !107 + %784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 2, !dbg !107 + %785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 3, !dbg !107 + %786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 4, !dbg !107 + %787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 5, !dbg !107 + %788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 6, !dbg !107 + %789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 7, !dbg !107 + %790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 8, !dbg !107 + %791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 9, !dbg !107 + %792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 10, !dbg !107 + %793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 11, !dbg !107 + %794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 12, !dbg !107 + %795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 13, !dbg !107 + %796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 14, !dbg !107 + %797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 15, !dbg !107 + %798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 16, !dbg !107 + %799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 17, !dbg !107 + %800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 18, !dbg !107 + %801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 19, !dbg !107 + %802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 20, !dbg !107 + %803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 21, !dbg !107 + %804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 22, !dbg !107 + %805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 23, !dbg !107 + %806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 24, !dbg !107 + %807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 25, !dbg !107 + %808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 26, !dbg !107 + %809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 27, !dbg !107 + %810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 28, !dbg !107 + %811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 29, !dbg !107 + %812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 30, !dbg !107 + %813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %770, 31, !dbg !107 + %814 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %782, float %783, float %784, float %785, float %786, float %787, float %788, float %789, float %790, float %791, float %792, float %793, float %794, float %795, float %796, float %797, float %798, float %799, float %800, float %801, float %802, float %803, float %804, float %805, float %806, float %807, float %808, float %809, float %810, float %811, float %812, float %813, i64 %776, i64 %781, i1 true) #3, !dbg !107 + %815 = or disjoint i32 %627, 16416, !dbg !107 + %816 = add i32 %815, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %817 = lshr exact i32 %816, 4, !dbg !107 + %818 = and i32 %817, 16383, !dbg !107 + %819 = zext nneg i32 %818 to i64, !dbg !107 + %820 = or disjoint i64 %819, 4611686293372403712, !dbg !107 + %821 = add i32 %633, 8224, !dbg !107 + %822 = lshr exact i32 %821, 4, !dbg !107 + %823 = and i32 %822, 16383, !dbg !107 + %824 = zext nneg i32 %823 to i64, !dbg !107 + %825 = or disjoint i64 %824, 4611686293338849280, !dbg !107 + %826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 0, !dbg !107 + %827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 1, !dbg !107 + %828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 2, !dbg !107 + %829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 3, !dbg !107 + %830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 4, !dbg !107 + %831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 5, !dbg !107 + %832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 6, !dbg !107 + %833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 7, !dbg !107 + %834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 8, !dbg !107 + %835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 9, !dbg !107 + %836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 10, !dbg !107 + %837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 11, !dbg !107 + %838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 12, !dbg !107 + %839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 13, !dbg !107 + %840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 14, !dbg !107 + %841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 15, !dbg !107 + %842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 16, !dbg !107 + %843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 17, !dbg !107 + %844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 18, !dbg !107 + %845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 19, !dbg !107 + %846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 20, !dbg !107 + %847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 21, !dbg !107 + %848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 22, !dbg !107 + %849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 23, !dbg !107 + %850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 24, !dbg !107 + %851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 25, !dbg !107 + %852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 26, !dbg !107 + %853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 27, !dbg !107 + %854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 28, !dbg !107 + %855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 29, !dbg !107 + %856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 30, !dbg !107 + %857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %814, 31, !dbg !107 + %858 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %826, float %827, float %828, float %829, float %830, float %831, float %832, float %833, float %834, float %835, float %836, float %837, float %838, float %839, float %840, float %841, float %842, float %843, float %844, float %845, float %846, float %847, float %848, float %849, float %850, float %851, float %852, float %853, float %854, float %855, float %856, float %857, i64 %820, i64 %825, i1 true) #3, !dbg !107 + %859 = or disjoint i32 %627, 16448, !dbg !107 + %860 = add i32 %859, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %861 = lshr exact i32 %860, 4, !dbg !107 + %862 = and i32 %861, 16383, !dbg !107 + %863 = zext nneg i32 %862 to i64, !dbg !107 + %864 = or disjoint i64 %863, 4611686293372403712, !dbg !107 + %865 = add i32 %633, 8256, !dbg !107 + %866 = lshr exact i32 %865, 4, !dbg !107 + %867 = and i32 %866, 16383, !dbg !107 + %868 = zext nneg i32 %867 to i64, !dbg !107 + %869 = or disjoint i64 %868, 4611686293338849280, !dbg !107 + %870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 0, !dbg !107 + %871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 1, !dbg !107 + %872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 2, !dbg !107 + %873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 3, !dbg !107 + %874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 4, !dbg !107 + %875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 5, !dbg !107 + %876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 6, !dbg !107 + %877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 7, !dbg !107 + %878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 8, !dbg !107 + %879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 9, !dbg !107 + %880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 10, !dbg !107 + %881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 11, !dbg !107 + %882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 12, !dbg !107 + %883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 13, !dbg !107 + %884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 14, !dbg !107 + %885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 15, !dbg !107 + %886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 16, !dbg !107 + %887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 17, !dbg !107 + %888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 18, !dbg !107 + %889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 19, !dbg !107 + %890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 20, !dbg !107 + %891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 21, !dbg !107 + %892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 22, !dbg !107 + %893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 23, !dbg !107 + %894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 24, !dbg !107 + %895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 25, !dbg !107 + %896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 26, !dbg !107 + %897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 27, !dbg !107 + %898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 28, !dbg !107 + %899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 29, !dbg !107 + %900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 30, !dbg !107 + %901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %858, 31, !dbg !107 + %902 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %870, float %871, float %872, float %873, float %874, float %875, float %876, float %877, float %878, float %879, float %880, float %881, float %882, float %883, float %884, float %885, float %886, float %887, float %888, float %889, float %890, float %891, float %892, float %893, float %894, float %895, float %896, float %897, float %898, float %899, float %900, float %901, i64 %864, i64 %869, i1 true) #3, !dbg !107 + %903 = or disjoint i32 %627, 16480, !dbg !107 + %904 = add i32 %903, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !107 + %905 = lshr exact i32 %904, 4, !dbg !107 + %906 = and i32 %905, 16383, !dbg !107 + %907 = zext nneg i32 %906 to i64, !dbg !107 + %908 = or disjoint i64 %907, 4611686293372403712, !dbg !107 + %909 = add i32 %633, 8288, !dbg !107 + %910 = lshr exact i32 %909, 4, !dbg !107 + %911 = and i32 %910, 16383, !dbg !107 + %912 = zext nneg i32 %911 to i64, !dbg !107 + %913 = or disjoint i64 %912, 4611686293338849280, !dbg !107 + %914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 0, !dbg !107 + %915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 1, !dbg !107 + %916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 2, !dbg !107 + %917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 3, !dbg !107 + %918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 4, !dbg !107 + %919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 5, !dbg !107 + %920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 6, !dbg !107 + %921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 7, !dbg !107 + %922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 8, !dbg !107 + %923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 9, !dbg !107 + %924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 10, !dbg !107 + %925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 11, !dbg !107 + %926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 12, !dbg !107 + %927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 13, !dbg !107 + %928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 14, !dbg !107 + %929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 15, !dbg !107 + %930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 16, !dbg !107 + %931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 17, !dbg !107 + %932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 18, !dbg !107 + %933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 19, !dbg !107 + %934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 20, !dbg !107 + %935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 21, !dbg !107 + %936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 22, !dbg !107 + %937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 23, !dbg !107 + %938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 24, !dbg !107 + %939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 25, !dbg !107 + %940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 26, !dbg !107 + %941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 27, !dbg !107 + %942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 28, !dbg !107 + %943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 29, !dbg !107 + %944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 30, !dbg !107 + %945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %902, 31, !dbg !107 + %946 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %914, float %915, float %916, float %917, float %918, float %919, float %920, float %921, float %922, float %923, float %924, float %925, float %926, float %927, float %928, float %929, float %930, float %931, float %932, float %933, float %934, float %935, float %936, float %937, float %938, float %939, float %940, float %941, float %942, float %943, float %944, float %945, i64 %908, i64 %913, i1 true) #3, !dbg !107 + %947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 0, !dbg !107 + %948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 1, !dbg !107 + %949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 2, !dbg !107 + %950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 3, !dbg !107 + %951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 4, !dbg !107 + %952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 5, !dbg !107 + %953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 6, !dbg !107 + %954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 7, !dbg !107 + %955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 8, !dbg !107 + %956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 9, !dbg !107 + %957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 10, !dbg !107 + %958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 11, !dbg !107 + %959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 12, !dbg !107 + %960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 13, !dbg !107 + %961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 14, !dbg !107 + %962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 15, !dbg !107 + %963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 16, !dbg !107 + %964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 17, !dbg !107 + %965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 18, !dbg !107 + %966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 19, !dbg !107 + %967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 20, !dbg !107 + %968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 21, !dbg !107 + %969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 22, !dbg !107 + %970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 23, !dbg !107 + %971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 24, !dbg !107 + %972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 25, !dbg !107 + %973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 26, !dbg !107 + %974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 27, !dbg !107 + %975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 28, !dbg !107 + %976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 29, !dbg !107 + %977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 30, !dbg !107 + %978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %946, 31, !dbg !107 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !107 + %979 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %947, float %948, float %949, float %950, float %951, float %952, float %953, float %954, float %955, float %956, float %957, float %958, float %959, float %960, float %961, float %962, float %963, float %964, float %965, float %966, float %967, float %968, float %969, float %970, float %971, float %972, float %973, float %974, float %975, float %976, float %977, float %978, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %624, i32 0, i32 0) #3, !dbg !107 + %980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 0, !dbg !107 + %981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 1, !dbg !107 + %982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 2, !dbg !107 + %983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 3, !dbg !107 + %984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 4, !dbg !107 + %985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 5, !dbg !107 + %986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 6, !dbg !107 + %987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 7, !dbg !107 + %988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 8, !dbg !107 + %989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 9, !dbg !107 + %990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 10, !dbg !107 + %991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 11, !dbg !107 + %992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 12, !dbg !107 + %993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 13, !dbg !107 + %994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 14, !dbg !107 + %995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 15, !dbg !107 + %996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 16, !dbg !107 + %997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 17, !dbg !107 + %998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 18, !dbg !107 + %999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 19, !dbg !107 + %1000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 20, !dbg !107 + %1001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 21, !dbg !107 + %1002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 22, !dbg !107 + %1003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 23, !dbg !107 + %1004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 24, !dbg !107 + %1005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 25, !dbg !107 + %1006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 26, !dbg !107 + %1007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 27, !dbg !107 + %1008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 28, !dbg !107 + %1009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 29, !dbg !107 + %1010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 30, !dbg !107 + %1011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %979, 31, !dbg !107 + %1012 = fmul float %980, 0x3FB6A09E60000000, !dbg !119 + %1013 = fmul float %981, 0x3FB6A09E60000000, !dbg !119 + %1014 = fmul float %982, 0x3FB6A09E60000000, !dbg !119 + %1015 = fmul float %983, 0x3FB6A09E60000000, !dbg !119 + %1016 = fmul float %984, 0x3FB6A09E60000000, !dbg !119 + %1017 = fmul float %985, 0x3FB6A09E60000000, !dbg !119 + %1018 = fmul float %986, 0x3FB6A09E60000000, !dbg !119 + %1019 = fmul float %987, 0x3FB6A09E60000000, !dbg !119 + %1020 = fmul float %988, 0x3FB6A09E60000000, !dbg !119 + %1021 = fmul float %989, 0x3FB6A09E60000000, !dbg !119 + %1022 = fmul float %990, 0x3FB6A09E60000000, !dbg !119 + %1023 = fmul float %991, 0x3FB6A09E60000000, !dbg !119 + %1024 = fmul float %992, 0x3FB6A09E60000000, !dbg !119 + %1025 = fmul float %993, 0x3FB6A09E60000000, !dbg !119 + %1026 = fmul float %994, 0x3FB6A09E60000000, !dbg !119 + %1027 = fmul float %995, 0x3FB6A09E60000000, !dbg !119 + %1028 = fmul float %996, 0x3FB6A09E60000000, !dbg !119 + %1029 = fmul float %997, 0x3FB6A09E60000000, !dbg !119 + %1030 = fmul float %998, 0x3FB6A09E60000000, !dbg !119 + %1031 = fmul float %999, 0x3FB6A09E60000000, !dbg !119 + %1032 = fmul float %1000, 0x3FB6A09E60000000, !dbg !119 + %1033 = fmul float %1001, 0x3FB6A09E60000000, !dbg !119 + %1034 = fmul float %1002, 0x3FB6A09E60000000, !dbg !119 + %1035 = fmul float %1003, 0x3FB6A09E60000000, !dbg !119 + %1036 = fmul float %1004, 0x3FB6A09E60000000, !dbg !119 + %1037 = fmul float %1005, 0x3FB6A09E60000000, !dbg !119 + %1038 = fmul float %1006, 0x3FB6A09E60000000, !dbg !119 + %1039 = fmul float %1007, 0x3FB6A09E60000000, !dbg !119 + %1040 = fmul float %1008, 0x3FB6A09E60000000, !dbg !119 + %1041 = fmul float %1009, 0x3FB6A09E60000000, !dbg !119 + %1042 = fmul float %1010, 0x3FB6A09E60000000, !dbg !119 + %1043 = fmul float %1011, 0x3FB6A09E60000000, !dbg !119 + %1044 = srem <2 x i32> %609, %520, !dbg !108 + %1045 = icmp sle <2 x i32> %1044, %514, !dbg !115 + %1046 = and <2 x i1> %512, %1045, !dbg !114 + %1047 = icmp slt <2 x i32> %1044, zeroinitializer, !dbg !120 + %1048 = and <2 x i1> %516, %1047, !dbg !116 + %1049 = or <2 x i32> %1044, %514, !dbg !121 + %1050 = icmp sgt <2 x i32> %1049, splat (i32 -1), !dbg !121 + %1051 = and <2 x i32> %1044, splat (i32 15), !dbg !122 + %1052 = icmp ne <2 x i32> %1051, zeroinitializer, !dbg !122 + %1053 = sdiv <2 x i32> %1044, splat (i32 16), !dbg !123 + %1054 = and <2 x i1> %1047, %1052, !dbg !124 + %1055 = sext <2 x i1> %1054 to <2 x i32>, !dbg !124 + %1056 = add nsw <2 x i32> %1053, %1055, !dbg !124 + %1057 = icmp eq <2 x i32> %518, %1056, !dbg !117 + %1058 = and <2 x i1> %1050, %1057, !dbg !125 + %1059 = or <2 x i1> %1048, %1058, !dbg !126 + %1060 = or <2 x i1> %1046, %1059, !dbg !127 + %1061 = icmp sle <2 x i32> %1044, %526, !dbg !115 + %1062 = and <2 x i1> %524, %1061, !dbg !114 + %1063 = and <2 x i1> %528, %1047, !dbg !116 + %1064 = or <2 x i32> %1044, %526, !dbg !121 + %1065 = icmp sgt <2 x i32> %1064, splat (i32 -1), !dbg !121 + %1066 = icmp eq <2 x i32> %530, %1056, !dbg !117 + %1067 = and <2 x i1> %1065, %1066, !dbg !125 + %1068 = or <2 x i1> %1063, %1067, !dbg !126 + %1069 = or <2 x i1> %1062, %1068, !dbg !127 + %1070 = select <2 x i1> %1069, <2 x i1> %615, <2 x i1> zeroinitializer, !dbg !128 + %1071 = select <2 x i1> %1060, <2 x i1> %615, <2 x i1> zeroinitializer, !dbg !128 + %1072 = srem <2 x i32> %608, %520, !dbg !108 + %1073 = icmp sle <2 x i32> %1072, %514, !dbg !115 + %1074 = and <2 x i1> %512, %1073, !dbg !114 + %1075 = icmp slt <2 x i32> %1072, zeroinitializer, !dbg !120 + %1076 = and <2 x i1> %516, %1075, !dbg !116 + %1077 = or <2 x i32> %1072, %514, !dbg !121 + %1078 = icmp sgt <2 x i32> %1077, splat (i32 -1), !dbg !121 + %1079 = and <2 x i32> %1072, splat (i32 15), !dbg !122 + %1080 = icmp ne <2 x i32> %1079, zeroinitializer, !dbg !122 + %1081 = sdiv <2 x i32> %1072, splat (i32 16), !dbg !123 + %1082 = and <2 x i1> %1075, %1080, !dbg !124 + %1083 = sext <2 x i1> %1082 to <2 x i32>, !dbg !124 + %1084 = add nsw <2 x i32> %1081, %1083, !dbg !124 + %1085 = icmp eq <2 x i32> %518, %1084, !dbg !117 + %1086 = and <2 x i1> %1078, %1085, !dbg !125 + %1087 = or <2 x i1> %1076, %1086, !dbg !126 + %1088 = or <2 x i1> %1074, %1087, !dbg !127 + %1089 = icmp sle <2 x i32> %1072, %526, !dbg !115 + %1090 = and <2 x i1> %524, %1089, !dbg !114 + %1091 = and <2 x i1> %528, %1075, !dbg !116 + %1092 = or <2 x i32> %1072, %526, !dbg !121 + %1093 = icmp sgt <2 x i32> %1092, splat (i32 -1), !dbg !121 + %1094 = icmp eq <2 x i32> %530, %1084, !dbg !117 + %1095 = and <2 x i1> %1093, %1094, !dbg !125 + %1096 = or <2 x i1> %1091, %1095, !dbg !126 + %1097 = or <2 x i1> %1090, %1096, !dbg !127 + %1098 = select <2 x i1> %1097, <2 x i1> %616, <2 x i1> zeroinitializer, !dbg !128 + %1099 = select <2 x i1> %1088, <2 x i1> %616, <2 x i1> zeroinitializer, !dbg !128 + %1100 = srem <2 x i32> %607, %520, !dbg !108 + %1101 = icmp sle <2 x i32> %1100, %514, !dbg !115 + %1102 = and <2 x i1> %512, %1101, !dbg !114 + %1103 = icmp slt <2 x i32> %1100, zeroinitializer, !dbg !120 + %1104 = and <2 x i1> %516, %1103, !dbg !116 + %1105 = or <2 x i32> %1100, %514, !dbg !121 + %1106 = icmp sgt <2 x i32> %1105, splat (i32 -1), !dbg !121 + %1107 = and <2 x i32> %1100, splat (i32 15), !dbg !122 + %1108 = icmp ne <2 x i32> %1107, zeroinitializer, !dbg !122 + %1109 = sdiv <2 x i32> %1100, splat (i32 16), !dbg !123 + %1110 = and <2 x i1> %1103, %1108, !dbg !124 + %1111 = sext <2 x i1> %1110 to <2 x i32>, !dbg !124 + %1112 = add nsw <2 x i32> %1109, %1111, !dbg !124 + %1113 = icmp eq <2 x i32> %518, %1112, !dbg !117 + %1114 = and <2 x i1> %1106, %1113, !dbg !125 + %1115 = or <2 x i1> %1104, %1114, !dbg !126 + %1116 = or <2 x i1> %1102, %1115, !dbg !127 + %1117 = icmp sle <2 x i32> %1100, %526, !dbg !115 + %1118 = and <2 x i1> %524, %1117, !dbg !114 + %1119 = and <2 x i1> %528, %1103, !dbg !116 + %1120 = or <2 x i32> %1100, %526, !dbg !121 + %1121 = icmp sgt <2 x i32> %1120, splat (i32 -1), !dbg !121 + %1122 = icmp eq <2 x i32> %530, %1112, !dbg !117 + %1123 = and <2 x i1> %1121, %1122, !dbg !125 + %1124 = or <2 x i1> %1119, %1123, !dbg !126 + %1125 = or <2 x i1> %1118, %1124, !dbg !127 + %1126 = select <2 x i1> %1125, <2 x i1> %617, <2 x i1> zeroinitializer, !dbg !128 + %1127 = select <2 x i1> %1116, <2 x i1> %617, <2 x i1> zeroinitializer, !dbg !128 + %1128 = srem <2 x i32> %606, %520, !dbg !108 + %1129 = icmp sle <2 x i32> %1128, %514, !dbg !115 + %1130 = and <2 x i1> %512, %1129, !dbg !114 + %1131 = icmp slt <2 x i32> %1128, zeroinitializer, !dbg !120 + %1132 = and <2 x i1> %516, %1131, !dbg !116 + %1133 = or <2 x i32> %1128, %514, !dbg !121 + %1134 = icmp sgt <2 x i32> %1133, splat (i32 -1), !dbg !121 + %1135 = and <2 x i32> %1128, splat (i32 15), !dbg !122 + %1136 = icmp ne <2 x i32> %1135, zeroinitializer, !dbg !122 + %1137 = sdiv <2 x i32> %1128, splat (i32 16), !dbg !123 + %1138 = and <2 x i1> %1131, %1136, !dbg !124 + %1139 = sext <2 x i1> %1138 to <2 x i32>, !dbg !124 + %1140 = add nsw <2 x i32> %1137, %1139, !dbg !124 + %1141 = icmp eq <2 x i32> %518, %1140, !dbg !117 + %1142 = and <2 x i1> %1134, %1141, !dbg !125 + %1143 = or <2 x i1> %1132, %1142, !dbg !126 + %1144 = or <2 x i1> %1130, %1143, !dbg !127 + %1145 = icmp sle <2 x i32> %1128, %526, !dbg !115 + %1146 = and <2 x i1> %524, %1145, !dbg !114 + %1147 = and <2 x i1> %528, %1131, !dbg !116 + %1148 = or <2 x i32> %1128, %526, !dbg !121 + %1149 = icmp sgt <2 x i32> %1148, splat (i32 -1), !dbg !121 + %1150 = icmp eq <2 x i32> %530, %1140, !dbg !117 + %1151 = and <2 x i1> %1149, %1150, !dbg !125 + %1152 = or <2 x i1> %1147, %1151, !dbg !126 + %1153 = or <2 x i1> %1146, %1152, !dbg !127 + %1154 = select <2 x i1> %1153, <2 x i1> %618, <2 x i1> zeroinitializer, !dbg !128 + %1155 = select <2 x i1> %1144, <2 x i1> %618, <2 x i1> zeroinitializer, !dbg !128 + %1156 = srem <2 x i32> %605, %520, !dbg !108 + %1157 = icmp sle <2 x i32> %1156, %514, !dbg !115 + %1158 = and <2 x i1> %512, %1157, !dbg !114 + %1159 = icmp slt <2 x i32> %1156, zeroinitializer, !dbg !120 + %1160 = and <2 x i1> %516, %1159, !dbg !116 + %1161 = or <2 x i32> %1156, %514, !dbg !121 + %1162 = icmp sgt <2 x i32> %1161, splat (i32 -1), !dbg !121 + %1163 = and <2 x i32> %1156, splat (i32 15), !dbg !122 + %1164 = icmp ne <2 x i32> %1163, zeroinitializer, !dbg !122 + %1165 = sdiv <2 x i32> %1156, splat (i32 16), !dbg !123 + %1166 = and <2 x i1> %1159, %1164, !dbg !124 + %1167 = sext <2 x i1> %1166 to <2 x i32>, !dbg !124 + %1168 = add nsw <2 x i32> %1165, %1167, !dbg !124 + %1169 = icmp eq <2 x i32> %518, %1168, !dbg !117 + %1170 = and <2 x i1> %1162, %1169, !dbg !125 + %1171 = or <2 x i1> %1160, %1170, !dbg !126 + %1172 = or <2 x i1> %1158, %1171, !dbg !127 + %1173 = icmp sle <2 x i32> %1156, %526, !dbg !115 + %1174 = and <2 x i1> %524, %1173, !dbg !114 + %1175 = and <2 x i1> %528, %1159, !dbg !116 + %1176 = or <2 x i32> %1156, %526, !dbg !121 + %1177 = icmp sgt <2 x i32> %1176, splat (i32 -1), !dbg !121 + %1178 = icmp eq <2 x i32> %530, %1168, !dbg !117 + %1179 = and <2 x i1> %1177, %1178, !dbg !125 + %1180 = or <2 x i1> %1175, %1179, !dbg !126 + %1181 = or <2 x i1> %1174, %1180, !dbg !127 + %1182 = select <2 x i1> %1181, <2 x i1> %619, <2 x i1> zeroinitializer, !dbg !128 + %1183 = select <2 x i1> %1172, <2 x i1> %619, <2 x i1> zeroinitializer, !dbg !128 + %1184 = srem <2 x i32> %604, %520, !dbg !108 + %1185 = icmp sle <2 x i32> %1184, %514, !dbg !115 + %1186 = and <2 x i1> %512, %1185, !dbg !114 + %1187 = icmp slt <2 x i32> %1184, zeroinitializer, !dbg !120 + %1188 = and <2 x i1> %516, %1187, !dbg !116 + %1189 = or <2 x i32> %1184, %514, !dbg !121 + %1190 = icmp sgt <2 x i32> %1189, splat (i32 -1), !dbg !121 + %1191 = and <2 x i32> %1184, splat (i32 15), !dbg !122 + %1192 = icmp ne <2 x i32> %1191, zeroinitializer, !dbg !122 + %1193 = sdiv <2 x i32> %1184, splat (i32 16), !dbg !123 + %1194 = and <2 x i1> %1187, %1192, !dbg !124 + %1195 = sext <2 x i1> %1194 to <2 x i32>, !dbg !124 + %1196 = add nsw <2 x i32> %1193, %1195, !dbg !124 + %1197 = icmp eq <2 x i32> %518, %1196, !dbg !117 + %1198 = and <2 x i1> %1190, %1197, !dbg !125 + %1199 = or <2 x i1> %1188, %1198, !dbg !126 + %1200 = or <2 x i1> %1186, %1199, !dbg !127 + %1201 = icmp sle <2 x i32> %1184, %526, !dbg !115 + %1202 = and <2 x i1> %524, %1201, !dbg !114 + %1203 = and <2 x i1> %528, %1187, !dbg !116 + %1204 = or <2 x i32> %1184, %526, !dbg !121 + %1205 = icmp sgt <2 x i32> %1204, splat (i32 -1), !dbg !121 + %1206 = icmp eq <2 x i32> %530, %1196, !dbg !117 + %1207 = and <2 x i1> %1205, %1206, !dbg !125 + %1208 = or <2 x i1> %1203, %1207, !dbg !126 + %1209 = or <2 x i1> %1202, %1208, !dbg !127 + %1210 = select <2 x i1> %1209, <2 x i1> %620, <2 x i1> zeroinitializer, !dbg !128 + %1211 = select <2 x i1> %1200, <2 x i1> %620, <2 x i1> zeroinitializer, !dbg !128 + %1212 = srem <2 x i32> %603, %520, !dbg !108 + %1213 = icmp sle <2 x i32> %1212, %514, !dbg !115 + %1214 = and <2 x i1> %512, %1213, !dbg !114 + %1215 = icmp slt <2 x i32> %1212, zeroinitializer, !dbg !120 + %1216 = and <2 x i1> %516, %1215, !dbg !116 + %1217 = or <2 x i32> %1212, %514, !dbg !121 + %1218 = icmp sgt <2 x i32> %1217, splat (i32 -1), !dbg !121 + %1219 = and <2 x i32> %1212, splat (i32 15), !dbg !122 + %1220 = icmp ne <2 x i32> %1219, zeroinitializer, !dbg !122 + %1221 = sdiv <2 x i32> %1212, splat (i32 16), !dbg !123 + %1222 = and <2 x i1> %1215, %1220, !dbg !124 + %1223 = sext <2 x i1> %1222 to <2 x i32>, !dbg !124 + %1224 = add nsw <2 x i32> %1221, %1223, !dbg !124 + %1225 = icmp eq <2 x i32> %518, %1224, !dbg !117 + %1226 = and <2 x i1> %1218, %1225, !dbg !125 + %1227 = or <2 x i1> %1216, %1226, !dbg !126 + %1228 = or <2 x i1> %1214, %1227, !dbg !127 + %1229 = icmp sle <2 x i32> %1212, %526, !dbg !115 + %1230 = and <2 x i1> %524, %1229, !dbg !114 + %1231 = and <2 x i1> %528, %1215, !dbg !116 + %1232 = or <2 x i32> %1212, %526, !dbg !121 + %1233 = icmp sgt <2 x i32> %1232, splat (i32 -1), !dbg !121 + %1234 = icmp eq <2 x i32> %530, %1224, !dbg !117 + %1235 = and <2 x i1> %1233, %1234, !dbg !125 + %1236 = or <2 x i1> %1231, %1235, !dbg !126 + %1237 = or <2 x i1> %1230, %1236, !dbg !127 + %1238 = select <2 x i1> %1237, <2 x i1> %621, <2 x i1> zeroinitializer, !dbg !128 + %1239 = select <2 x i1> %1228, <2 x i1> %621, <2 x i1> zeroinitializer, !dbg !128 + %1240 = srem <2 x i32> %602, %520, !dbg !108 + %1241 = icmp sle <2 x i32> %1240, %514, !dbg !115 + %1242 = and <2 x i1> %512, %1241, !dbg !114 + %1243 = icmp slt <2 x i32> %1240, zeroinitializer, !dbg !120 + %1244 = and <2 x i1> %516, %1243, !dbg !116 + %1245 = or <2 x i32> %1240, %514, !dbg !121 + %1246 = icmp sgt <2 x i32> %1245, splat (i32 -1), !dbg !121 + %1247 = and <2 x i32> %1240, splat (i32 15), !dbg !122 + %1248 = icmp ne <2 x i32> %1247, zeroinitializer, !dbg !122 + %1249 = sdiv <2 x i32> %1240, splat (i32 16), !dbg !123 + %1250 = and <2 x i1> %1243, %1248, !dbg !124 + %1251 = sext <2 x i1> %1250 to <2 x i32>, !dbg !124 + %1252 = add nsw <2 x i32> %1249, %1251, !dbg !124 + %1253 = icmp eq <2 x i32> %518, %1252, !dbg !117 + %1254 = and <2 x i1> %1246, %1253, !dbg !125 + %1255 = or <2 x i1> %1244, %1254, !dbg !126 + %1256 = or <2 x i1> %1242, %1255, !dbg !127 + %1257 = icmp sle <2 x i32> %1240, %526, !dbg !115 + %1258 = and <2 x i1> %524, %1257, !dbg !114 + %1259 = and <2 x i1> %528, %1243, !dbg !116 + %1260 = or <2 x i32> %1240, %526, !dbg !121 + %1261 = icmp sgt <2 x i32> %1260, splat (i32 -1), !dbg !121 + %1262 = icmp eq <2 x i32> %530, %1252, !dbg !117 + %1263 = and <2 x i1> %1261, %1262, !dbg !125 + %1264 = or <2 x i1> %1259, %1263, !dbg !126 + %1265 = or <2 x i1> %1258, %1264, !dbg !127 + %1266 = select <2 x i1> %1265, <2 x i1> %622, <2 x i1> zeroinitializer, !dbg !128 + %1267 = select <2 x i1> %1256, <2 x i1> %622, <2 x i1> zeroinitializer, !dbg !128 + %1268 = fmul float %1012, 0x3FF7154760000000, !dbg !129 + %1269 = extractelement <2 x i1> %1070, i64 0, !dbg !128 + %1270 = select i1 %1269, float %1268, float 0xFFF0000000000000, !dbg !128 + %1271 = fmul float %1013, 0x3FF7154760000000, !dbg !129 + %1272 = extractelement <2 x i1> %1070, i64 1, !dbg !128 + %1273 = select i1 %1272, float %1271, float 0xFFF0000000000000, !dbg !128 + %1274 = fmul float %1014, 0x3FF7154760000000, !dbg !129 + %1275 = extractelement <2 x i1> %1071, i64 0, !dbg !128 + %1276 = select i1 %1275, float %1274, float 0xFFF0000000000000, !dbg !128 + %1277 = fmul float %1015, 0x3FF7154760000000, !dbg !129 + %1278 = extractelement <2 x i1> %1071, i64 1, !dbg !128 + %1279 = select i1 %1278, float %1277, float 0xFFF0000000000000, !dbg !128 + %1280 = fmul float %1016, 0x3FF7154760000000, !dbg !129 + %1281 = extractelement <2 x i1> %1098, i64 0, !dbg !128 + %1282 = select i1 %1281, float %1280, float 0xFFF0000000000000, !dbg !128 + %1283 = fmul float %1017, 0x3FF7154760000000, !dbg !129 + %1284 = extractelement <2 x i1> %1098, i64 1, !dbg !128 + %1285 = select i1 %1284, float %1283, float 0xFFF0000000000000, !dbg !128 + %1286 = fmul float %1018, 0x3FF7154760000000, !dbg !129 + %1287 = extractelement <2 x i1> %1099, i64 0, !dbg !128 + %1288 = select i1 %1287, float %1286, float 0xFFF0000000000000, !dbg !128 + %1289 = fmul float %1019, 0x3FF7154760000000, !dbg !129 + %1290 = extractelement <2 x i1> %1099, i64 1, !dbg !128 + %1291 = select i1 %1290, float %1289, float 0xFFF0000000000000, !dbg !128 + %1292 = fmul float %1020, 0x3FF7154760000000, !dbg !129 + %1293 = extractelement <2 x i1> %1126, i64 0, !dbg !128 + %1294 = select i1 %1293, float %1292, float 0xFFF0000000000000, !dbg !128 + %1295 = fmul float %1021, 0x3FF7154760000000, !dbg !129 + %1296 = extractelement <2 x i1> %1126, i64 1, !dbg !128 + %1297 = select i1 %1296, float %1295, float 0xFFF0000000000000, !dbg !128 + %1298 = fmul float %1022, 0x3FF7154760000000, !dbg !129 + %1299 = extractelement <2 x i1> %1127, i64 0, !dbg !128 + %1300 = select i1 %1299, float %1298, float 0xFFF0000000000000, !dbg !128 + %1301 = fmul float %1023, 0x3FF7154760000000, !dbg !129 + %1302 = extractelement <2 x i1> %1127, i64 1, !dbg !128 + %1303 = select i1 %1302, float %1301, float 0xFFF0000000000000, !dbg !128 + %1304 = fmul float %1024, 0x3FF7154760000000, !dbg !129 + %1305 = extractelement <2 x i1> %1154, i64 0, !dbg !128 + %1306 = select i1 %1305, float %1304, float 0xFFF0000000000000, !dbg !128 + %1307 = fmul float %1025, 0x3FF7154760000000, !dbg !129 + %1308 = extractelement <2 x i1> %1154, i64 1, !dbg !128 + %1309 = select i1 %1308, float %1307, float 0xFFF0000000000000, !dbg !128 + %1310 = fmul float %1026, 0x3FF7154760000000, !dbg !129 + %1311 = extractelement <2 x i1> %1155, i64 0, !dbg !128 + %1312 = select i1 %1311, float %1310, float 0xFFF0000000000000, !dbg !128 + %1313 = fmul float %1027, 0x3FF7154760000000, !dbg !129 + %1314 = extractelement <2 x i1> %1155, i64 1, !dbg !128 + %1315 = select i1 %1314, float %1313, float 0xFFF0000000000000, !dbg !128 + %1316 = fmul float %1028, 0x3FF7154760000000, !dbg !129 + %1317 = extractelement <2 x i1> %1182, i64 0, !dbg !128 + %1318 = select i1 %1317, float %1316, float 0xFFF0000000000000, !dbg !128 + %1319 = fmul float %1029, 0x3FF7154760000000, !dbg !129 + %1320 = extractelement <2 x i1> %1182, i64 1, !dbg !128 + %1321 = select i1 %1320, float %1319, float 0xFFF0000000000000, !dbg !128 + %1322 = fmul float %1030, 0x3FF7154760000000, !dbg !129 + %1323 = extractelement <2 x i1> %1183, i64 0, !dbg !128 + %1324 = select i1 %1323, float %1322, float 0xFFF0000000000000, !dbg !128 + %1325 = fmul float %1031, 0x3FF7154760000000, !dbg !129 + %1326 = extractelement <2 x i1> %1183, i64 1, !dbg !128 + %1327 = select i1 %1326, float %1325, float 0xFFF0000000000000, !dbg !128 + %1328 = fmul float %1032, 0x3FF7154760000000, !dbg !129 + %1329 = extractelement <2 x i1> %1210, i64 0, !dbg !128 + %1330 = select i1 %1329, float %1328, float 0xFFF0000000000000, !dbg !128 + %1331 = fmul float %1033, 0x3FF7154760000000, !dbg !129 + %1332 = extractelement <2 x i1> %1210, i64 1, !dbg !128 + %1333 = select i1 %1332, float %1331, float 0xFFF0000000000000, !dbg !128 + %1334 = fmul float %1034, 0x3FF7154760000000, !dbg !129 + %1335 = extractelement <2 x i1> %1211, i64 0, !dbg !128 + %1336 = select i1 %1335, float %1334, float 0xFFF0000000000000, !dbg !128 + %1337 = fmul float %1035, 0x3FF7154760000000, !dbg !129 + %1338 = extractelement <2 x i1> %1211, i64 1, !dbg !128 + %1339 = select i1 %1338, float %1337, float 0xFFF0000000000000, !dbg !128 + %1340 = fmul float %1036, 0x3FF7154760000000, !dbg !129 + %1341 = extractelement <2 x i1> %1238, i64 0, !dbg !128 + %1342 = select i1 %1341, float %1340, float 0xFFF0000000000000, !dbg !128 + %1343 = fmul float %1037, 0x3FF7154760000000, !dbg !129 + %1344 = extractelement <2 x i1> %1238, i64 1, !dbg !128 + %1345 = select i1 %1344, float %1343, float 0xFFF0000000000000, !dbg !128 + %1346 = fmul float %1038, 0x3FF7154760000000, !dbg !129 + %1347 = extractelement <2 x i1> %1239, i64 0, !dbg !128 + %1348 = select i1 %1347, float %1346, float 0xFFF0000000000000, !dbg !128 + %1349 = fmul float %1039, 0x3FF7154760000000, !dbg !129 + %1350 = extractelement <2 x i1> %1239, i64 1, !dbg !128 + %1351 = select i1 %1350, float %1349, float 0xFFF0000000000000, !dbg !128 + %1352 = fmul float %1040, 0x3FF7154760000000, !dbg !129 + %1353 = extractelement <2 x i1> %1266, i64 0, !dbg !128 + %1354 = select i1 %1353, float %1352, float 0xFFF0000000000000, !dbg !128 + %1355 = fmul float %1041, 0x3FF7154760000000, !dbg !129 + %1356 = extractelement <2 x i1> %1266, i64 1, !dbg !128 + %1357 = select i1 %1356, float %1355, float 0xFFF0000000000000, !dbg !128 + %1358 = fmul float %1042, 0x3FF7154760000000, !dbg !129 + %1359 = extractelement <2 x i1> %1267, i64 0, !dbg !128 + %1360 = select i1 %1359, float %1358, float 0xFFF0000000000000, !dbg !128 + %1361 = fmul float %1043, 0x3FF7154760000000, !dbg !129 + %1362 = extractelement <2 x i1> %1267, i64 1, !dbg !128 + %1363 = select i1 %1362, float %1361, float 0xFFF0000000000000, !dbg !128 + %1364 = fsub float %1270, %366, !dbg !130 + %1365 = fsub float %1273, %366, !dbg !130 + %1366 = fsub float %1276, %367, !dbg !130 + %1367 = fsub float %1279, %367, !dbg !130 + %1368 = fsub float %1282, %366, !dbg !130 + %1369 = fsub float %1285, %366, !dbg !130 + %1370 = fsub float %1288, %367, !dbg !130 + %1371 = fsub float %1291, %367, !dbg !130 + %1372 = fsub float %1294, %366, !dbg !130 + %1373 = fsub float %1297, %366, !dbg !130 + %1374 = fsub float %1300, %367, !dbg !130 + %1375 = fsub float %1303, %367, !dbg !130 + %1376 = fsub float %1306, %366, !dbg !130 + %1377 = fsub float %1309, %366, !dbg !130 + %1378 = fsub float %1312, %367, !dbg !130 + %1379 = fsub float %1315, %367, !dbg !130 + %1380 = fsub float %1318, %366, !dbg !130 + %1381 = fsub float %1321, %366, !dbg !130 + %1382 = fsub float %1324, %367, !dbg !130 + %1383 = fsub float %1327, %367, !dbg !130 + %1384 = fsub float %1330, %366, !dbg !130 + %1385 = fsub float %1333, %366, !dbg !130 + %1386 = fsub float %1336, %367, !dbg !130 + %1387 = fsub float %1339, %367, !dbg !130 + %1388 = fsub float %1342, %366, !dbg !130 + %1389 = fsub float %1345, %366, !dbg !130 + %1390 = fsub float %1348, %367, !dbg !130 + %1391 = fsub float %1351, %367, !dbg !130 + %1392 = fsub float %1354, %366, !dbg !130 + %1393 = fsub float %1357, %366, !dbg !130 + %1394 = fsub float %1360, %367, !dbg !130 + %1395 = fsub float %1363, %367, !dbg !130 + %1396 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1514 = icmp eq i32 %1396, 0, !dbg !131 + br i1 %.not.i1514, label %1399, label %1397, !dbg !131 + +1397: ; preds = %533 + %1398 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1364) #3, !dbg !131 + br label %__nv_exp2f.exit1516, !dbg !131 + +1399: ; preds = %533 + %1400 = tail call float @llvm.nvvm.ex2.approx.f(float %1364) #3, !dbg !131 + br label %__nv_exp2f.exit1516, !dbg !131 + +__nv_exp2f.exit1516: ; preds = %1397, %1399 + %.0.i1515 = phi float [ %1398, %1397 ], [ %1400, %1399 ], !dbg !131 + %1401 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1517 = icmp eq i32 %1401, 0, !dbg !131 + br i1 %.not.i1517, label %1404, label %1402, !dbg !131 + +1402: ; preds = %__nv_exp2f.exit1516 + %1403 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1365) #3, !dbg !131 + br label %__nv_exp2f.exit1519, !dbg !131 + +1404: ; preds = %__nv_exp2f.exit1516 + %1405 = tail call float @llvm.nvvm.ex2.approx.f(float %1365) #3, !dbg !131 + br label %__nv_exp2f.exit1519, !dbg !131 + +__nv_exp2f.exit1519: ; preds = %1402, %1404 + %.0.i1518 = phi float [ %1403, %1402 ], [ %1405, %1404 ], !dbg !131 + %1406 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1520 = icmp eq i32 %1406, 0, !dbg !131 + br i1 %.not.i1520, label %1409, label %1407, !dbg !131 + +1407: ; preds = %__nv_exp2f.exit1519 + %1408 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1366) #3, !dbg !131 + br label %__nv_exp2f.exit1522, !dbg !131 + +1409: ; preds = %__nv_exp2f.exit1519 + %1410 = tail call float @llvm.nvvm.ex2.approx.f(float %1366) #3, !dbg !131 + br label %__nv_exp2f.exit1522, !dbg !131 + +__nv_exp2f.exit1522: ; preds = %1407, %1409 + %.0.i1521 = phi float [ %1408, %1407 ], [ %1410, %1409 ], !dbg !131 + %1411 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1523 = icmp eq i32 %1411, 0, !dbg !131 + br i1 %.not.i1523, label %1414, label %1412, !dbg !131 + +1412: ; preds = %__nv_exp2f.exit1522 + %1413 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1367) #3, !dbg !131 + br label %__nv_exp2f.exit1525, !dbg !131 + +1414: ; preds = %__nv_exp2f.exit1522 + %1415 = tail call float @llvm.nvvm.ex2.approx.f(float %1367) #3, !dbg !131 + br label %__nv_exp2f.exit1525, !dbg !131 + +__nv_exp2f.exit1525: ; preds = %1412, %1414 + %.0.i1524 = phi float [ %1413, %1412 ], [ %1415, %1414 ], !dbg !131 + %1416 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1526 = icmp eq i32 %1416, 0, !dbg !131 + br i1 %.not.i1526, label %1419, label %1417, !dbg !131 + +1417: ; preds = %__nv_exp2f.exit1525 + %1418 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1368) #3, !dbg !131 + br label %__nv_exp2f.exit1528, !dbg !131 + +1419: ; preds = %__nv_exp2f.exit1525 + %1420 = tail call float @llvm.nvvm.ex2.approx.f(float %1368) #3, !dbg !131 + br label %__nv_exp2f.exit1528, !dbg !131 + +__nv_exp2f.exit1528: ; preds = %1417, %1419 + %.0.i1527 = phi float [ %1418, %1417 ], [ %1420, %1419 ], !dbg !131 + %1421 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1529 = icmp eq i32 %1421, 0, !dbg !131 + br i1 %.not.i1529, label %1424, label %1422, !dbg !131 + +1422: ; preds = %__nv_exp2f.exit1528 + %1423 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1369) #3, !dbg !131 + br label %__nv_exp2f.exit1531, !dbg !131 + +1424: ; preds = %__nv_exp2f.exit1528 + %1425 = tail call float @llvm.nvvm.ex2.approx.f(float %1369) #3, !dbg !131 + br label %__nv_exp2f.exit1531, !dbg !131 + +__nv_exp2f.exit1531: ; preds = %1422, %1424 + %.0.i1530 = phi float [ %1423, %1422 ], [ %1425, %1424 ], !dbg !131 + %1426 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1532 = icmp eq i32 %1426, 0, !dbg !131 + br i1 %.not.i1532, label %1429, label %1427, !dbg !131 + +1427: ; preds = %__nv_exp2f.exit1531 + %1428 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1370) #3, !dbg !131 + br label %__nv_exp2f.exit1534, !dbg !131 + +1429: ; preds = %__nv_exp2f.exit1531 + %1430 = tail call float @llvm.nvvm.ex2.approx.f(float %1370) #3, !dbg !131 + br label %__nv_exp2f.exit1534, !dbg !131 + +__nv_exp2f.exit1534: ; preds = %1427, %1429 + %.0.i1533 = phi float [ %1428, %1427 ], [ %1430, %1429 ], !dbg !131 + %1431 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1535 = icmp eq i32 %1431, 0, !dbg !131 + br i1 %.not.i1535, label %1434, label %1432, !dbg !131 + +1432: ; preds = %__nv_exp2f.exit1534 + %1433 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1371) #3, !dbg !131 + br label %__nv_exp2f.exit1537, !dbg !131 + +1434: ; preds = %__nv_exp2f.exit1534 + %1435 = tail call float @llvm.nvvm.ex2.approx.f(float %1371) #3, !dbg !131 + br label %__nv_exp2f.exit1537, !dbg !131 + +__nv_exp2f.exit1537: ; preds = %1432, %1434 + %.0.i1536 = phi float [ %1433, %1432 ], [ %1435, %1434 ], !dbg !131 + %1436 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1538 = icmp eq i32 %1436, 0, !dbg !131 + br i1 %.not.i1538, label %1439, label %1437, !dbg !131 + +1437: ; preds = %__nv_exp2f.exit1537 + %1438 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1372) #3, !dbg !131 + br label %__nv_exp2f.exit1540, !dbg !131 + +1439: ; preds = %__nv_exp2f.exit1537 + %1440 = tail call float @llvm.nvvm.ex2.approx.f(float %1372) #3, !dbg !131 + br label %__nv_exp2f.exit1540, !dbg !131 + +__nv_exp2f.exit1540: ; preds = %1437, %1439 + %.0.i1539 = phi float [ %1438, %1437 ], [ %1440, %1439 ], !dbg !131 + %1441 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1541 = icmp eq i32 %1441, 0, !dbg !131 + br i1 %.not.i1541, label %1444, label %1442, !dbg !131 + +1442: ; preds = %__nv_exp2f.exit1540 + %1443 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1373) #3, !dbg !131 + br label %__nv_exp2f.exit1543, !dbg !131 + +1444: ; preds = %__nv_exp2f.exit1540 + %1445 = tail call float @llvm.nvvm.ex2.approx.f(float %1373) #3, !dbg !131 + br label %__nv_exp2f.exit1543, !dbg !131 + +__nv_exp2f.exit1543: ; preds = %1442, %1444 + %.0.i1542 = phi float [ %1443, %1442 ], [ %1445, %1444 ], !dbg !131 + %1446 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1544 = icmp eq i32 %1446, 0, !dbg !131 + br i1 %.not.i1544, label %1449, label %1447, !dbg !131 + +1447: ; preds = %__nv_exp2f.exit1543 + %1448 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1374) #3, !dbg !131 + br label %__nv_exp2f.exit1546, !dbg !131 + +1449: ; preds = %__nv_exp2f.exit1543 + %1450 = tail call float @llvm.nvvm.ex2.approx.f(float %1374) #3, !dbg !131 + br label %__nv_exp2f.exit1546, !dbg !131 + +__nv_exp2f.exit1546: ; preds = %1447, %1449 + %.0.i1545 = phi float [ %1448, %1447 ], [ %1450, %1449 ], !dbg !131 + %1451 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1547 = icmp eq i32 %1451, 0, !dbg !131 + br i1 %.not.i1547, label %1454, label %1452, !dbg !131 + +1452: ; preds = %__nv_exp2f.exit1546 + %1453 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1375) #3, !dbg !131 + br label %__nv_exp2f.exit1549, !dbg !131 + +1454: ; preds = %__nv_exp2f.exit1546 + %1455 = tail call float @llvm.nvvm.ex2.approx.f(float %1375) #3, !dbg !131 + br label %__nv_exp2f.exit1549, !dbg !131 + +__nv_exp2f.exit1549: ; preds = %1452, %1454 + %.0.i1548 = phi float [ %1453, %1452 ], [ %1455, %1454 ], !dbg !131 + %1456 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1550 = icmp eq i32 %1456, 0, !dbg !131 + br i1 %.not.i1550, label %1459, label %1457, !dbg !131 + +1457: ; preds = %__nv_exp2f.exit1549 + %1458 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1376) #3, !dbg !131 + br label %__nv_exp2f.exit1552, !dbg !131 + +1459: ; preds = %__nv_exp2f.exit1549 + %1460 = tail call float @llvm.nvvm.ex2.approx.f(float %1376) #3, !dbg !131 + br label %__nv_exp2f.exit1552, !dbg !131 + +__nv_exp2f.exit1552: ; preds = %1457, %1459 + %.0.i1551 = phi float [ %1458, %1457 ], [ %1460, %1459 ], !dbg !131 + %1461 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1553 = icmp eq i32 %1461, 0, !dbg !131 + br i1 %.not.i1553, label %1464, label %1462, !dbg !131 + +1462: ; preds = %__nv_exp2f.exit1552 + %1463 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1377) #3, !dbg !131 + br label %__nv_exp2f.exit1555, !dbg !131 + +1464: ; preds = %__nv_exp2f.exit1552 + %1465 = tail call float @llvm.nvvm.ex2.approx.f(float %1377) #3, !dbg !131 + br label %__nv_exp2f.exit1555, !dbg !131 + +__nv_exp2f.exit1555: ; preds = %1462, %1464 + %.0.i1554 = phi float [ %1463, %1462 ], [ %1465, %1464 ], !dbg !131 + %1466 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1556 = icmp eq i32 %1466, 0, !dbg !131 + br i1 %.not.i1556, label %1469, label %1467, !dbg !131 + +1467: ; preds = %__nv_exp2f.exit1555 + %1468 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1378) #3, !dbg !131 + br label %__nv_exp2f.exit1558, !dbg !131 + +1469: ; preds = %__nv_exp2f.exit1555 + %1470 = tail call float @llvm.nvvm.ex2.approx.f(float %1378) #3, !dbg !131 + br label %__nv_exp2f.exit1558, !dbg !131 + +__nv_exp2f.exit1558: ; preds = %1467, %1469 + %.0.i1557 = phi float [ %1468, %1467 ], [ %1470, %1469 ], !dbg !131 + %1471 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1559 = icmp eq i32 %1471, 0, !dbg !131 + br i1 %.not.i1559, label %1474, label %1472, !dbg !131 + +1472: ; preds = %__nv_exp2f.exit1558 + %1473 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1379) #3, !dbg !131 + br label %__nv_exp2f.exit1561, !dbg !131 + +1474: ; preds = %__nv_exp2f.exit1558 + %1475 = tail call float @llvm.nvvm.ex2.approx.f(float %1379) #3, !dbg !131 + br label %__nv_exp2f.exit1561, !dbg !131 + +__nv_exp2f.exit1561: ; preds = %1472, %1474 + %.0.i1560 = phi float [ %1473, %1472 ], [ %1475, %1474 ], !dbg !131 + %1476 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1562 = icmp eq i32 %1476, 0, !dbg !131 + br i1 %.not.i1562, label %1479, label %1477, !dbg !131 + +1477: ; preds = %__nv_exp2f.exit1561 + %1478 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1380) #3, !dbg !131 + br label %__nv_exp2f.exit1564, !dbg !131 + +1479: ; preds = %__nv_exp2f.exit1561 + %1480 = tail call float @llvm.nvvm.ex2.approx.f(float %1380) #3, !dbg !131 + br label %__nv_exp2f.exit1564, !dbg !131 + +__nv_exp2f.exit1564: ; preds = %1477, %1479 + %.0.i1563 = phi float [ %1478, %1477 ], [ %1480, %1479 ], !dbg !131 + %1481 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1565 = icmp eq i32 %1481, 0, !dbg !131 + br i1 %.not.i1565, label %1484, label %1482, !dbg !131 + +1482: ; preds = %__nv_exp2f.exit1564 + %1483 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1381) #3, !dbg !131 + br label %__nv_exp2f.exit1567, !dbg !131 + +1484: ; preds = %__nv_exp2f.exit1564 + %1485 = tail call float @llvm.nvvm.ex2.approx.f(float %1381) #3, !dbg !131 + br label %__nv_exp2f.exit1567, !dbg !131 + +__nv_exp2f.exit1567: ; preds = %1482, %1484 + %.0.i1566 = phi float [ %1483, %1482 ], [ %1485, %1484 ], !dbg !131 + %1486 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1568 = icmp eq i32 %1486, 0, !dbg !131 + br i1 %.not.i1568, label %1489, label %1487, !dbg !131 + +1487: ; preds = %__nv_exp2f.exit1567 + %1488 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1382) #3, !dbg !131 + br label %__nv_exp2f.exit1570, !dbg !131 + +1489: ; preds = %__nv_exp2f.exit1567 + %1490 = tail call float @llvm.nvvm.ex2.approx.f(float %1382) #3, !dbg !131 + br label %__nv_exp2f.exit1570, !dbg !131 + +__nv_exp2f.exit1570: ; preds = %1487, %1489 + %.0.i1569 = phi float [ %1488, %1487 ], [ %1490, %1489 ], !dbg !131 + %1491 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1571 = icmp eq i32 %1491, 0, !dbg !131 + br i1 %.not.i1571, label %1494, label %1492, !dbg !131 + +1492: ; preds = %__nv_exp2f.exit1570 + %1493 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1383) #3, !dbg !131 + br label %__nv_exp2f.exit1573, !dbg !131 + +1494: ; preds = %__nv_exp2f.exit1570 + %1495 = tail call float @llvm.nvvm.ex2.approx.f(float %1383) #3, !dbg !131 + br label %__nv_exp2f.exit1573, !dbg !131 + +__nv_exp2f.exit1573: ; preds = %1492, %1494 + %.0.i1572 = phi float [ %1493, %1492 ], [ %1495, %1494 ], !dbg !131 + %1496 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1574 = icmp eq i32 %1496, 0, !dbg !131 + br i1 %.not.i1574, label %1499, label %1497, !dbg !131 + +1497: ; preds = %__nv_exp2f.exit1573 + %1498 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1384) #3, !dbg !131 + br label %__nv_exp2f.exit1576, !dbg !131 + +1499: ; preds = %__nv_exp2f.exit1573 + %1500 = tail call float @llvm.nvvm.ex2.approx.f(float %1384) #3, !dbg !131 + br label %__nv_exp2f.exit1576, !dbg !131 + +__nv_exp2f.exit1576: ; preds = %1497, %1499 + %.0.i1575 = phi float [ %1498, %1497 ], [ %1500, %1499 ], !dbg !131 + %1501 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1577 = icmp eq i32 %1501, 0, !dbg !131 + br i1 %.not.i1577, label %1504, label %1502, !dbg !131 + +1502: ; preds = %__nv_exp2f.exit1576 + %1503 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1385) #3, !dbg !131 + br label %__nv_exp2f.exit1579, !dbg !131 + +1504: ; preds = %__nv_exp2f.exit1576 + %1505 = tail call float @llvm.nvvm.ex2.approx.f(float %1385) #3, !dbg !131 + br label %__nv_exp2f.exit1579, !dbg !131 + +__nv_exp2f.exit1579: ; preds = %1502, %1504 + %.0.i1578 = phi float [ %1503, %1502 ], [ %1505, %1504 ], !dbg !131 + %1506 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1580 = icmp eq i32 %1506, 0, !dbg !131 + br i1 %.not.i1580, label %1509, label %1507, !dbg !131 + +1507: ; preds = %__nv_exp2f.exit1579 + %1508 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1386) #3, !dbg !131 + br label %__nv_exp2f.exit1582, !dbg !131 + +1509: ; preds = %__nv_exp2f.exit1579 + %1510 = tail call float @llvm.nvvm.ex2.approx.f(float %1386) #3, !dbg !131 + br label %__nv_exp2f.exit1582, !dbg !131 + +__nv_exp2f.exit1582: ; preds = %1507, %1509 + %.0.i1581 = phi float [ %1508, %1507 ], [ %1510, %1509 ], !dbg !131 + %1511 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1583 = icmp eq i32 %1511, 0, !dbg !131 + br i1 %.not.i1583, label %1514, label %1512, !dbg !131 + +1512: ; preds = %__nv_exp2f.exit1582 + %1513 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1387) #3, !dbg !131 + br label %__nv_exp2f.exit1585, !dbg !131 + +1514: ; preds = %__nv_exp2f.exit1582 + %1515 = tail call float @llvm.nvvm.ex2.approx.f(float %1387) #3, !dbg !131 + br label %__nv_exp2f.exit1585, !dbg !131 + +__nv_exp2f.exit1585: ; preds = %1512, %1514 + %.0.i1584 = phi float [ %1513, %1512 ], [ %1515, %1514 ], !dbg !131 + %1516 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1586 = icmp eq i32 %1516, 0, !dbg !131 + br i1 %.not.i1586, label %1519, label %1517, !dbg !131 + +1517: ; preds = %__nv_exp2f.exit1585 + %1518 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1388) #3, !dbg !131 + br label %__nv_exp2f.exit1588, !dbg !131 + +1519: ; preds = %__nv_exp2f.exit1585 + %1520 = tail call float @llvm.nvvm.ex2.approx.f(float %1388) #3, !dbg !131 + br label %__nv_exp2f.exit1588, !dbg !131 + +__nv_exp2f.exit1588: ; preds = %1517, %1519 + %.0.i1587 = phi float [ %1518, %1517 ], [ %1520, %1519 ], !dbg !131 + %1521 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1589 = icmp eq i32 %1521, 0, !dbg !131 + br i1 %.not.i1589, label %1524, label %1522, !dbg !131 + +1522: ; preds = %__nv_exp2f.exit1588 + %1523 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1389) #3, !dbg !131 + br label %__nv_exp2f.exit1591, !dbg !131 + +1524: ; preds = %__nv_exp2f.exit1588 + %1525 = tail call float @llvm.nvvm.ex2.approx.f(float %1389) #3, !dbg !131 + br label %__nv_exp2f.exit1591, !dbg !131 + +__nv_exp2f.exit1591: ; preds = %1522, %1524 + %.0.i1590 = phi float [ %1523, %1522 ], [ %1525, %1524 ], !dbg !131 + %1526 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1592 = icmp eq i32 %1526, 0, !dbg !131 + br i1 %.not.i1592, label %1529, label %1527, !dbg !131 + +1527: ; preds = %__nv_exp2f.exit1591 + %1528 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1390) #3, !dbg !131 + br label %__nv_exp2f.exit1594, !dbg !131 + +1529: ; preds = %__nv_exp2f.exit1591 + %1530 = tail call float @llvm.nvvm.ex2.approx.f(float %1390) #3, !dbg !131 + br label %__nv_exp2f.exit1594, !dbg !131 + +__nv_exp2f.exit1594: ; preds = %1527, %1529 + %.0.i1593 = phi float [ %1528, %1527 ], [ %1530, %1529 ], !dbg !131 + %1531 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1595 = icmp eq i32 %1531, 0, !dbg !131 + br i1 %.not.i1595, label %1534, label %1532, !dbg !131 + +1532: ; preds = %__nv_exp2f.exit1594 + %1533 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1391) #3, !dbg !131 + br label %__nv_exp2f.exit1597, !dbg !131 + +1534: ; preds = %__nv_exp2f.exit1594 + %1535 = tail call float @llvm.nvvm.ex2.approx.f(float %1391) #3, !dbg !131 + br label %__nv_exp2f.exit1597, !dbg !131 + +__nv_exp2f.exit1597: ; preds = %1532, %1534 + %.0.i1596 = phi float [ %1533, %1532 ], [ %1535, %1534 ], !dbg !131 + %1536 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1598 = icmp eq i32 %1536, 0, !dbg !131 + br i1 %.not.i1598, label %1539, label %1537, !dbg !131 + +1537: ; preds = %__nv_exp2f.exit1597 + %1538 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1392) #3, !dbg !131 + br label %__nv_exp2f.exit1600, !dbg !131 + +1539: ; preds = %__nv_exp2f.exit1597 + %1540 = tail call float @llvm.nvvm.ex2.approx.f(float %1392) #3, !dbg !131 + br label %__nv_exp2f.exit1600, !dbg !131 + +__nv_exp2f.exit1600: ; preds = %1537, %1539 + %.0.i1599 = phi float [ %1538, %1537 ], [ %1540, %1539 ], !dbg !131 + %1541 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1601 = icmp eq i32 %1541, 0, !dbg !131 + br i1 %.not.i1601, label %1544, label %1542, !dbg !131 + +1542: ; preds = %__nv_exp2f.exit1600 + %1543 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1393) #3, !dbg !131 + br label %__nv_exp2f.exit1603, !dbg !131 + +1544: ; preds = %__nv_exp2f.exit1600 + %1545 = tail call float @llvm.nvvm.ex2.approx.f(float %1393) #3, !dbg !131 + br label %__nv_exp2f.exit1603, !dbg !131 + +__nv_exp2f.exit1603: ; preds = %1542, %1544 + %.0.i1602 = phi float [ %1543, %1542 ], [ %1545, %1544 ], !dbg !131 + %1546 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1604 = icmp eq i32 %1546, 0, !dbg !131 + br i1 %.not.i1604, label %1549, label %1547, !dbg !131 + +1547: ; preds = %__nv_exp2f.exit1603 + %1548 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1394) #3, !dbg !131 + br label %__nv_exp2f.exit1606, !dbg !131 + +1549: ; preds = %__nv_exp2f.exit1603 + %1550 = tail call float @llvm.nvvm.ex2.approx.f(float %1394) #3, !dbg !131 + br label %__nv_exp2f.exit1606, !dbg !131 + +__nv_exp2f.exit1606: ; preds = %1547, %1549 + %.0.i1605 = phi float [ %1548, %1547 ], [ %1550, %1549 ], !dbg !131 + %1551 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !131 + %.not.i1607 = icmp eq i32 %1551, 0, !dbg !131 + br i1 %.not.i1607, label %1554, label %1552, !dbg !131 + +1552: ; preds = %__nv_exp2f.exit1606 + %1553 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %1395) #3, !dbg !131 + br label %__nv_exp2f.exit1609, !dbg !131 + +1554: ; preds = %__nv_exp2f.exit1606 + %1555 = tail call float @llvm.nvvm.ex2.approx.f(float %1395) #3, !dbg !131 + br label %__nv_exp2f.exit1609, !dbg !131 + +__nv_exp2f.exit1609: ; preds = %1552, %1554 + %.0.i1608 = phi float [ %1553, %1552 ], [ %1555, %1554 ], !dbg !131 + %1556 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %623, !dbg !103 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !132 + %1557 = add i32 %627, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1558 = lshr exact i32 %1557, 4, !dbg !132 + %1559 = and i32 %1558, 16383, !dbg !132 + %1560 = zext nneg i32 %1559 to i64, !dbg !132 + %1561 = or disjoint i64 %1560, 4611686293372403712, !dbg !132 + %1562 = ptrtoint ptr addrspace(3) %1556 to i32, !dbg !132 + %1563 = lshr exact i32 %1562, 4, !dbg !132 + %1564 = and i32 %1563, 16383, !dbg !132 + %1565 = zext nneg i32 %1564 to i64, !dbg !132 + %1566 = or disjoint i64 %1565, 4611686293338849280, !dbg !132 + %1567 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %1561, i64 %1566) #3, !dbg !132 + %1568 = add i32 %639, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1569 = lshr exact i32 %1568, 4, !dbg !132 + %1570 = and i32 %1569, 16383, !dbg !132 + %1571 = zext nneg i32 %1570 to i64, !dbg !132 + %1572 = or disjoint i64 %1571, 4611686293372403712, !dbg !132 + %1573 = add i32 %1562, 32, !dbg !132 + %1574 = lshr exact i32 %1573, 4, !dbg !132 + %1575 = and i32 %1574, 16383, !dbg !132 + %1576 = zext nneg i32 %1575 to i64, !dbg !132 + %1577 = or disjoint i64 %1576, 4611686293338849280, !dbg !132 + %1578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 0, !dbg !132 + %1579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 1, !dbg !132 + %1580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 2, !dbg !132 + %1581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 3, !dbg !132 + %1582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 4, !dbg !132 + %1583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 5, !dbg !132 + %1584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 6, !dbg !132 + %1585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 7, !dbg !132 + %1586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 8, !dbg !132 + %1587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 9, !dbg !132 + %1588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 10, !dbg !132 + %1589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 11, !dbg !132 + %1590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 12, !dbg !132 + %1591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 13, !dbg !132 + %1592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 14, !dbg !132 + %1593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 15, !dbg !132 + %1594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 16, !dbg !132 + %1595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 17, !dbg !132 + %1596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 18, !dbg !132 + %1597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 19, !dbg !132 + %1598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 20, !dbg !132 + %1599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 21, !dbg !132 + %1600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 22, !dbg !132 + %1601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 23, !dbg !132 + %1602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 24, !dbg !132 + %1603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 25, !dbg !132 + %1604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 26, !dbg !132 + %1605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 27, !dbg !132 + %1606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 28, !dbg !132 + %1607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 29, !dbg !132 + %1608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 30, !dbg !132 + %1609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1567, 31, !dbg !132 + %1610 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1578, float %1579, float %1580, float %1581, float %1582, float %1583, float %1584, float %1585, float %1586, float %1587, float %1588, float %1589, float %1590, float %1591, float %1592, float %1593, float %1594, float %1595, float %1596, float %1597, float %1598, float %1599, float %1600, float %1601, float %1602, float %1603, float %1604, float %1605, float %1606, float %1607, float %1608, float %1609, i64 %1572, i64 %1577, i1 true) #3, !dbg !132 + %1611 = add i32 %683, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1612 = lshr exact i32 %1611, 4, !dbg !132 + %1613 = and i32 %1612, 16383, !dbg !132 + %1614 = zext nneg i32 %1613 to i64, !dbg !132 + %1615 = or disjoint i64 %1614, 4611686293372403712, !dbg !132 + %1616 = add i32 %1562, 64, !dbg !132 + %1617 = lshr exact i32 %1616, 4, !dbg !132 + %1618 = and i32 %1617, 16383, !dbg !132 + %1619 = zext nneg i32 %1618 to i64, !dbg !132 + %1620 = or disjoint i64 %1619, 4611686293338849280, !dbg !132 + %1621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 0, !dbg !132 + %1622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 1, !dbg !132 + %1623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 2, !dbg !132 + %1624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 3, !dbg !132 + %1625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 4, !dbg !132 + %1626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 5, !dbg !132 + %1627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 6, !dbg !132 + %1628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 7, !dbg !132 + %1629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 8, !dbg !132 + %1630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 9, !dbg !132 + %1631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 10, !dbg !132 + %1632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 11, !dbg !132 + %1633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 12, !dbg !132 + %1634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 13, !dbg !132 + %1635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 14, !dbg !132 + %1636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 15, !dbg !132 + %1637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 16, !dbg !132 + %1638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 17, !dbg !132 + %1639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 18, !dbg !132 + %1640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 19, !dbg !132 + %1641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 20, !dbg !132 + %1642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 21, !dbg !132 + %1643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 22, !dbg !132 + %1644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 23, !dbg !132 + %1645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 24, !dbg !132 + %1646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 25, !dbg !132 + %1647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 26, !dbg !132 + %1648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 27, !dbg !132 + %1649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 28, !dbg !132 + %1650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 29, !dbg !132 + %1651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 30, !dbg !132 + %1652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1610, 31, !dbg !132 + %1653 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1621, float %1622, float %1623, float %1624, float %1625, float %1626, float %1627, float %1628, float %1629, float %1630, float %1631, float %1632, float %1633, float %1634, float %1635, float %1636, float %1637, float %1638, float %1639, float %1640, float %1641, float %1642, float %1643, float %1644, float %1645, float %1646, float %1647, float %1648, float %1649, float %1650, float %1651, float %1652, i64 %1615, i64 %1620, i1 true) #3, !dbg !132 + %1654 = add i32 %727, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1655 = lshr exact i32 %1654, 4, !dbg !132 + %1656 = and i32 %1655, 16383, !dbg !132 + %1657 = zext nneg i32 %1656 to i64, !dbg !132 + %1658 = or disjoint i64 %1657, 4611686293372403712, !dbg !132 + %1659 = add i32 %1562, 96, !dbg !132 + %1660 = lshr exact i32 %1659, 4, !dbg !132 + %1661 = and i32 %1660, 16383, !dbg !132 + %1662 = zext nneg i32 %1661 to i64, !dbg !132 + %1663 = or disjoint i64 %1662, 4611686293338849280, !dbg !132 + %1664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 0, !dbg !132 + %1665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 1, !dbg !132 + %1666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 2, !dbg !132 + %1667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 3, !dbg !132 + %1668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 4, !dbg !132 + %1669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 5, !dbg !132 + %1670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 6, !dbg !132 + %1671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 7, !dbg !132 + %1672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 8, !dbg !132 + %1673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 9, !dbg !132 + %1674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 10, !dbg !132 + %1675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 11, !dbg !132 + %1676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 12, !dbg !132 + %1677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 13, !dbg !132 + %1678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 14, !dbg !132 + %1679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 15, !dbg !132 + %1680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 16, !dbg !132 + %1681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 17, !dbg !132 + %1682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 18, !dbg !132 + %1683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 19, !dbg !132 + %1684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 20, !dbg !132 + %1685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 21, !dbg !132 + %1686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 22, !dbg !132 + %1687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 23, !dbg !132 + %1688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 24, !dbg !132 + %1689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 25, !dbg !132 + %1690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 26, !dbg !132 + %1691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 27, !dbg !132 + %1692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 28, !dbg !132 + %1693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 29, !dbg !132 + %1694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 30, !dbg !132 + %1695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1653, 31, !dbg !132 + %1696 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1664, float %1665, float %1666, float %1667, float %1668, float %1669, float %1670, float %1671, float %1672, float %1673, float %1674, float %1675, float %1676, float %1677, float %1678, float %1679, float %1680, float %1681, float %1682, float %1683, float %1684, float %1685, float %1686, float %1687, float %1688, float %1689, float %1690, float %1691, float %1692, float %1693, float %1694, float %1695, i64 %1658, i64 %1663, i1 true) #3, !dbg !132 + %1697 = add i32 %771, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1698 = lshr exact i32 %1697, 4, !dbg !132 + %1699 = and i32 %1698, 16383, !dbg !132 + %1700 = zext nneg i32 %1699 to i64, !dbg !132 + %1701 = or disjoint i64 %1700, 4611686293372403712, !dbg !132 + %1702 = add i32 %1562, 8192, !dbg !132 + %1703 = lshr exact i32 %1702, 4, !dbg !132 + %1704 = and i32 %1703, 16383, !dbg !132 + %1705 = zext nneg i32 %1704 to i64, !dbg !132 + %1706 = or disjoint i64 %1705, 4611686293338849280, !dbg !132 + %1707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 0, !dbg !132 + %1708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 1, !dbg !132 + %1709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 2, !dbg !132 + %1710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 3, !dbg !132 + %1711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 4, !dbg !132 + %1712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 5, !dbg !132 + %1713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 6, !dbg !132 + %1714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 7, !dbg !132 + %1715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 8, !dbg !132 + %1716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 9, !dbg !132 + %1717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 10, !dbg !132 + %1718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 11, !dbg !132 + %1719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 12, !dbg !132 + %1720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 13, !dbg !132 + %1721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 14, !dbg !132 + %1722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 15, !dbg !132 + %1723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 16, !dbg !132 + %1724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 17, !dbg !132 + %1725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 18, !dbg !132 + %1726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 19, !dbg !132 + %1727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 20, !dbg !132 + %1728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 21, !dbg !132 + %1729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 22, !dbg !132 + %1730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 23, !dbg !132 + %1731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 24, !dbg !132 + %1732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 25, !dbg !132 + %1733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 26, !dbg !132 + %1734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 27, !dbg !132 + %1735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 28, !dbg !132 + %1736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 29, !dbg !132 + %1737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 30, !dbg !132 + %1738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1696, 31, !dbg !132 + %1739 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1707, float %1708, float %1709, float %1710, float %1711, float %1712, float %1713, float %1714, float %1715, float %1716, float %1717, float %1718, float %1719, float %1720, float %1721, float %1722, float %1723, float %1724, float %1725, float %1726, float %1727, float %1728, float %1729, float %1730, float %1731, float %1732, float %1733, float %1734, float %1735, float %1736, float %1737, float %1738, i64 %1701, i64 %1706, i1 true) #3, !dbg !132 + %1740 = add i32 %815, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1741 = lshr exact i32 %1740, 4, !dbg !132 + %1742 = and i32 %1741, 16383, !dbg !132 + %1743 = zext nneg i32 %1742 to i64, !dbg !132 + %1744 = or disjoint i64 %1743, 4611686293372403712, !dbg !132 + %1745 = add i32 %1562, 8224, !dbg !132 + %1746 = lshr exact i32 %1745, 4, !dbg !132 + %1747 = and i32 %1746, 16383, !dbg !132 + %1748 = zext nneg i32 %1747 to i64, !dbg !132 + %1749 = or disjoint i64 %1748, 4611686293338849280, !dbg !132 + %1750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 0, !dbg !132 + %1751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 1, !dbg !132 + %1752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 2, !dbg !132 + %1753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 3, !dbg !132 + %1754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 4, !dbg !132 + %1755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 5, !dbg !132 + %1756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 6, !dbg !132 + %1757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 7, !dbg !132 + %1758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 8, !dbg !132 + %1759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 9, !dbg !132 + %1760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 10, !dbg !132 + %1761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 11, !dbg !132 + %1762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 12, !dbg !132 + %1763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 13, !dbg !132 + %1764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 14, !dbg !132 + %1765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 15, !dbg !132 + %1766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 16, !dbg !132 + %1767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 17, !dbg !132 + %1768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 18, !dbg !132 + %1769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 19, !dbg !132 + %1770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 20, !dbg !132 + %1771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 21, !dbg !132 + %1772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 22, !dbg !132 + %1773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 23, !dbg !132 + %1774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 24, !dbg !132 + %1775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 25, !dbg !132 + %1776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 26, !dbg !132 + %1777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 27, !dbg !132 + %1778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 28, !dbg !132 + %1779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 29, !dbg !132 + %1780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 30, !dbg !132 + %1781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1739, 31, !dbg !132 + %1782 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1750, float %1751, float %1752, float %1753, float %1754, float %1755, float %1756, float %1757, float %1758, float %1759, float %1760, float %1761, float %1762, float %1763, float %1764, float %1765, float %1766, float %1767, float %1768, float %1769, float %1770, float %1771, float %1772, float %1773, float %1774, float %1775, float %1776, float %1777, float %1778, float %1779, float %1780, float %1781, i64 %1744, i64 %1749, i1 true) #3, !dbg !132 + %1783 = add i32 %859, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1784 = lshr exact i32 %1783, 4, !dbg !132 + %1785 = and i32 %1784, 16383, !dbg !132 + %1786 = zext nneg i32 %1785 to i64, !dbg !132 + %1787 = or disjoint i64 %1786, 4611686293372403712, !dbg !132 + %1788 = add i32 %1562, 8256, !dbg !132 + %1789 = lshr exact i32 %1788, 4, !dbg !132 + %1790 = and i32 %1789, 16383, !dbg !132 + %1791 = zext nneg i32 %1790 to i64, !dbg !132 + %1792 = or disjoint i64 %1791, 4611686293338849280, !dbg !132 + %1793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 0, !dbg !132 + %1794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 1, !dbg !132 + %1795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 2, !dbg !132 + %1796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 3, !dbg !132 + %1797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 4, !dbg !132 + %1798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 5, !dbg !132 + %1799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 6, !dbg !132 + %1800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 7, !dbg !132 + %1801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 8, !dbg !132 + %1802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 9, !dbg !132 + %1803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 10, !dbg !132 + %1804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 11, !dbg !132 + %1805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 12, !dbg !132 + %1806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 13, !dbg !132 + %1807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 14, !dbg !132 + %1808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 15, !dbg !132 + %1809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 16, !dbg !132 + %1810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 17, !dbg !132 + %1811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 18, !dbg !132 + %1812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 19, !dbg !132 + %1813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 20, !dbg !132 + %1814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 21, !dbg !132 + %1815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 22, !dbg !132 + %1816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 23, !dbg !132 + %1817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 24, !dbg !132 + %1818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 25, !dbg !132 + %1819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 26, !dbg !132 + %1820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 27, !dbg !132 + %1821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 28, !dbg !132 + %1822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 29, !dbg !132 + %1823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 30, !dbg !132 + %1824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1782, 31, !dbg !132 + %1825 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1793, float %1794, float %1795, float %1796, float %1797, float %1798, float %1799, float %1800, float %1801, float %1802, float %1803, float %1804, float %1805, float %1806, float %1807, float %1808, float %1809, float %1810, float %1811, float %1812, float %1813, float %1814, float %1815, float %1816, float %1817, float %1818, float %1819, float %1820, float %1821, float %1822, float %1823, float %1824, i64 %1787, i64 %1792, i1 true) #3, !dbg !132 + %1826 = add i32 %903, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !132 + %1827 = lshr exact i32 %1826, 4, !dbg !132 + %1828 = and i32 %1827, 16383, !dbg !132 + %1829 = zext nneg i32 %1828 to i64, !dbg !132 + %1830 = or disjoint i64 %1829, 4611686293372403712, !dbg !132 + %1831 = add i32 %1562, 8288, !dbg !132 + %1832 = lshr exact i32 %1831, 4, !dbg !132 + %1833 = and i32 %1832, 16383, !dbg !132 + %1834 = zext nneg i32 %1833 to i64, !dbg !132 + %1835 = or disjoint i64 %1834, 4611686293338849280, !dbg !132 + %1836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 0, !dbg !132 + %1837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 1, !dbg !132 + %1838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 2, !dbg !132 + %1839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 3, !dbg !132 + %1840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 4, !dbg !132 + %1841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 5, !dbg !132 + %1842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 6, !dbg !132 + %1843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 7, !dbg !132 + %1844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 8, !dbg !132 + %1845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 9, !dbg !132 + %1846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 10, !dbg !132 + %1847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 11, !dbg !132 + %1848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 12, !dbg !132 + %1849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 13, !dbg !132 + %1850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 14, !dbg !132 + %1851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 15, !dbg !132 + %1852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 16, !dbg !132 + %1853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 17, !dbg !132 + %1854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 18, !dbg !132 + %1855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 19, !dbg !132 + %1856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 20, !dbg !132 + %1857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 21, !dbg !132 + %1858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 22, !dbg !132 + %1859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 23, !dbg !132 + %1860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 24, !dbg !132 + %1861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 25, !dbg !132 + %1862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 26, !dbg !132 + %1863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 27, !dbg !132 + %1864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 28, !dbg !132 + %1865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 29, !dbg !132 + %1866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 30, !dbg !132 + %1867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1825, 31, !dbg !132 + %1868 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %1836, float %1837, float %1838, float %1839, float %1840, float %1841, float %1842, float %1843, float %1844, float %1845, float %1846, float %1847, float %1848, float %1849, float %1850, float %1851, float %1852, float %1853, float %1854, float %1855, float %1856, float %1857, float %1858, float %1859, float %1860, float %1861, float %1862, float %1863, float %1864, float %1865, float %1866, float %1867, i64 %1830, i64 %1835, i1 true) #3, !dbg !132 + %1869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 0, !dbg !132 + %1870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 1, !dbg !132 + %1871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 2, !dbg !132 + %1872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 3, !dbg !132 + %1873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 4, !dbg !132 + %1874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 5, !dbg !132 + %1875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 6, !dbg !132 + %1876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 7, !dbg !132 + %1877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 8, !dbg !132 + %1878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 9, !dbg !132 + %1879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 10, !dbg !132 + %1880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 11, !dbg !132 + %1881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 12, !dbg !132 + %1882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 13, !dbg !132 + %1883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 14, !dbg !132 + %1884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 15, !dbg !132 + %1885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 16, !dbg !132 + %1886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 17, !dbg !132 + %1887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 18, !dbg !132 + %1888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 19, !dbg !132 + %1889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 20, !dbg !132 + %1890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 21, !dbg !132 + %1891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 22, !dbg !132 + %1892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 23, !dbg !132 + %1893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 24, !dbg !132 + %1894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 25, !dbg !132 + %1895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 26, !dbg !132 + %1896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 27, !dbg !132 + %1897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 28, !dbg !132 + %1898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 29, !dbg !132 + %1899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 30, !dbg !132 + %1900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %1868, 31, !dbg !132 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !132 + %1901 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %1869, float %1870, float %1871, float %1872, float %1873, float %1874, float %1875, float %1876, float %1877, float %1878, float %1879, float %1880, float %1881, float %1882, float %1883, float %1884, float %1885, float %1886, float %1887, float %1888, float %1889, float %1890, float %1891, float %1892, float %1893, float %1894, float %1895, float %1896, float %1897, float %1898, float %1899, float %1900, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %1556, i32 0, i32 0) #3, !dbg !132 + %1902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 0, !dbg !132 + %1903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 1, !dbg !132 + %1904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 2, !dbg !132 + %1905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 3, !dbg !132 + %1906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 4, !dbg !132 + %1907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 5, !dbg !132 + %1908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 6, !dbg !132 + %1909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 7, !dbg !132 + %1910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 8, !dbg !132 + %1911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 9, !dbg !132 + %1912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 10, !dbg !132 + %1913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 11, !dbg !132 + %1914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 12, !dbg !132 + %1915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 13, !dbg !132 + %1916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 14, !dbg !132 + %1917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 15, !dbg !132 + %1918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 16, !dbg !132 + %1919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 17, !dbg !132 + %1920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 18, !dbg !132 + %1921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 19, !dbg !132 + %1922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 20, !dbg !132 + %1923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 21, !dbg !132 + %1924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 22, !dbg !132 + %1925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 23, !dbg !132 + %1926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 24, !dbg !132 + %1927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 25, !dbg !132 + %1928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 26, !dbg !132 + %1929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 27, !dbg !132 + %1930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 28, !dbg !132 + %1931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 29, !dbg !132 + %1932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 30, !dbg !132 + %1933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %1901, 31, !dbg !132 + %1934 = insertelement <2 x float> poison, float %1902, i64 0, !dbg !118 + %1935 = insertelement <2 x float> %1934, float %1903, i64 1, !dbg !118 + %1936 = fsub <2 x float> %1935, %532, !dbg !118 + %1937 = insertelement <2 x float> poison, float %.0.i1515, i64 0, !dbg !133 + %1938 = insertelement <2 x float> %1937, float %.0.i1518, i64 1, !dbg !133 + %1939 = fmul <2 x float> %1938, %1936, !dbg !133 + %1940 = fptrunc <2 x float> %1939 to <2 x bfloat>, !dbg !134 + %1941 = select <2 x i1> %1070, <2 x bfloat> %1940, <2 x bfloat> zeroinitializer, !dbg !135 + %1942 = insertelement <2 x float> poison, float %1904, i64 0, !dbg !118 + %1943 = insertelement <2 x float> %1942, float %1905, i64 1, !dbg !118 + %1944 = fsub <2 x float> %1943, %522, !dbg !118 + %1945 = insertelement <2 x float> poison, float %.0.i1521, i64 0, !dbg !133 + %1946 = insertelement <2 x float> %1945, float %.0.i1524, i64 1, !dbg !133 + %1947 = fmul <2 x float> %1946, %1944, !dbg !133 + %1948 = fptrunc <2 x float> %1947 to <2 x bfloat>, !dbg !134 + %1949 = select <2 x i1> %1071, <2 x bfloat> %1948, <2 x bfloat> zeroinitializer, !dbg !135 + %1950 = insertelement <2 x float> poison, float %1906, i64 0, !dbg !118 + %1951 = insertelement <2 x float> %1950, float %1907, i64 1, !dbg !118 + %1952 = fsub <2 x float> %1951, %532, !dbg !118 + %1953 = insertelement <2 x float> poison, float %.0.i1527, i64 0, !dbg !133 + %1954 = insertelement <2 x float> %1953, float %.0.i1530, i64 1, !dbg !133 + %1955 = fmul <2 x float> %1954, %1952, !dbg !133 + %1956 = fptrunc <2 x float> %1955 to <2 x bfloat>, !dbg !134 + %1957 = select <2 x i1> %1098, <2 x bfloat> %1956, <2 x bfloat> zeroinitializer, !dbg !135 + %1958 = insertelement <2 x float> poison, float %1908, i64 0, !dbg !118 + %1959 = insertelement <2 x float> %1958, float %1909, i64 1, !dbg !118 + %1960 = fsub <2 x float> %1959, %522, !dbg !118 + %1961 = insertelement <2 x float> poison, float %.0.i1533, i64 0, !dbg !133 + %1962 = insertelement <2 x float> %1961, float %.0.i1536, i64 1, !dbg !133 + %1963 = fmul <2 x float> %1962, %1960, !dbg !133 + %1964 = fptrunc <2 x float> %1963 to <2 x bfloat>, !dbg !134 + %1965 = select <2 x i1> %1099, <2 x bfloat> %1964, <2 x bfloat> zeroinitializer, !dbg !135 + %1966 = insertelement <2 x float> poison, float %1910, i64 0, !dbg !118 + %1967 = insertelement <2 x float> %1966, float %1911, i64 1, !dbg !118 + %1968 = fsub <2 x float> %1967, %532, !dbg !118 + %1969 = insertelement <2 x float> poison, float %.0.i1539, i64 0, !dbg !133 + %1970 = insertelement <2 x float> %1969, float %.0.i1542, i64 1, !dbg !133 + %1971 = fmul <2 x float> %1970, %1968, !dbg !133 + %1972 = fptrunc <2 x float> %1971 to <2 x bfloat>, !dbg !134 + %1973 = select <2 x i1> %1126, <2 x bfloat> %1972, <2 x bfloat> zeroinitializer, !dbg !135 + %1974 = insertelement <2 x float> poison, float %1912, i64 0, !dbg !118 + %1975 = insertelement <2 x float> %1974, float %1913, i64 1, !dbg !118 + %1976 = fsub <2 x float> %1975, %522, !dbg !118 + %1977 = insertelement <2 x float> poison, float %.0.i1545, i64 0, !dbg !133 + %1978 = insertelement <2 x float> %1977, float %.0.i1548, i64 1, !dbg !133 + %1979 = fmul <2 x float> %1978, %1976, !dbg !133 + %1980 = fptrunc <2 x float> %1979 to <2 x bfloat>, !dbg !134 + %1981 = select <2 x i1> %1127, <2 x bfloat> %1980, <2 x bfloat> zeroinitializer, !dbg !135 + %1982 = insertelement <2 x float> poison, float %1914, i64 0, !dbg !118 + %1983 = insertelement <2 x float> %1982, float %1915, i64 1, !dbg !118 + %1984 = fsub <2 x float> %1983, %532, !dbg !118 + %1985 = insertelement <2 x float> poison, float %.0.i1551, i64 0, !dbg !133 + %1986 = insertelement <2 x float> %1985, float %.0.i1554, i64 1, !dbg !133 + %1987 = fmul <2 x float> %1986, %1984, !dbg !133 + %1988 = fptrunc <2 x float> %1987 to <2 x bfloat>, !dbg !134 + %1989 = select <2 x i1> %1154, <2 x bfloat> %1988, <2 x bfloat> zeroinitializer, !dbg !135 + %1990 = insertelement <2 x float> poison, float %1916, i64 0, !dbg !118 + %1991 = insertelement <2 x float> %1990, float %1917, i64 1, !dbg !118 + %1992 = fsub <2 x float> %1991, %522, !dbg !118 + %1993 = insertelement <2 x float> poison, float %.0.i1557, i64 0, !dbg !133 + %1994 = insertelement <2 x float> %1993, float %.0.i1560, i64 1, !dbg !133 + %1995 = fmul <2 x float> %1994, %1992, !dbg !133 + %1996 = fptrunc <2 x float> %1995 to <2 x bfloat>, !dbg !134 + %1997 = select <2 x i1> %1155, <2 x bfloat> %1996, <2 x bfloat> zeroinitializer, !dbg !135 + %1998 = insertelement <2 x float> poison, float %1918, i64 0, !dbg !118 + %1999 = insertelement <2 x float> %1998, float %1919, i64 1, !dbg !118 + %2000 = fsub <2 x float> %1999, %532, !dbg !118 + %2001 = insertelement <2 x float> poison, float %.0.i1563, i64 0, !dbg !133 + %2002 = insertelement <2 x float> %2001, float %.0.i1566, i64 1, !dbg !133 + %2003 = fmul <2 x float> %2002, %2000, !dbg !133 + %2004 = fptrunc <2 x float> %2003 to <2 x bfloat>, !dbg !134 + %2005 = select <2 x i1> %1182, <2 x bfloat> %2004, <2 x bfloat> zeroinitializer, !dbg !135 + %2006 = insertelement <2 x float> poison, float %1920, i64 0, !dbg !118 + %2007 = insertelement <2 x float> %2006, float %1921, i64 1, !dbg !118 + %2008 = fsub <2 x float> %2007, %522, !dbg !118 + %2009 = insertelement <2 x float> poison, float %.0.i1569, i64 0, !dbg !133 + %2010 = insertelement <2 x float> %2009, float %.0.i1572, i64 1, !dbg !133 + %2011 = fmul <2 x float> %2010, %2008, !dbg !133 + %2012 = fptrunc <2 x float> %2011 to <2 x bfloat>, !dbg !134 + %2013 = select <2 x i1> %1183, <2 x bfloat> %2012, <2 x bfloat> zeroinitializer, !dbg !135 + %2014 = insertelement <2 x float> poison, float %1922, i64 0, !dbg !118 + %2015 = insertelement <2 x float> %2014, float %1923, i64 1, !dbg !118 + %2016 = fsub <2 x float> %2015, %532, !dbg !118 + %2017 = insertelement <2 x float> poison, float %.0.i1575, i64 0, !dbg !133 + %2018 = insertelement <2 x float> %2017, float %.0.i1578, i64 1, !dbg !133 + %2019 = fmul <2 x float> %2018, %2016, !dbg !133 + %2020 = fptrunc <2 x float> %2019 to <2 x bfloat>, !dbg !134 + %2021 = select <2 x i1> %1210, <2 x bfloat> %2020, <2 x bfloat> zeroinitializer, !dbg !135 + %2022 = insertelement <2 x float> poison, float %1924, i64 0, !dbg !118 + %2023 = insertelement <2 x float> %2022, float %1925, i64 1, !dbg !118 + %2024 = fsub <2 x float> %2023, %522, !dbg !118 + %2025 = insertelement <2 x float> poison, float %.0.i1581, i64 0, !dbg !133 + %2026 = insertelement <2 x float> %2025, float %.0.i1584, i64 1, !dbg !133 + %2027 = fmul <2 x float> %2026, %2024, !dbg !133 + %2028 = fptrunc <2 x float> %2027 to <2 x bfloat>, !dbg !134 + %2029 = select <2 x i1> %1211, <2 x bfloat> %2028, <2 x bfloat> zeroinitializer, !dbg !135 + %2030 = insertelement <2 x float> poison, float %1926, i64 0, !dbg !118 + %2031 = insertelement <2 x float> %2030, float %1927, i64 1, !dbg !118 + %2032 = fsub <2 x float> %2031, %532, !dbg !118 + %2033 = insertelement <2 x float> poison, float %.0.i1587, i64 0, !dbg !133 + %2034 = insertelement <2 x float> %2033, float %.0.i1590, i64 1, !dbg !133 + %2035 = fmul <2 x float> %2034, %2032, !dbg !133 + %2036 = fptrunc <2 x float> %2035 to <2 x bfloat>, !dbg !134 + %2037 = select <2 x i1> %1238, <2 x bfloat> %2036, <2 x bfloat> zeroinitializer, !dbg !135 + %2038 = insertelement <2 x float> poison, float %1928, i64 0, !dbg !118 + %2039 = insertelement <2 x float> %2038, float %1929, i64 1, !dbg !118 + %2040 = fsub <2 x float> %2039, %522, !dbg !118 + %2041 = insertelement <2 x float> poison, float %.0.i1593, i64 0, !dbg !133 + %2042 = insertelement <2 x float> %2041, float %.0.i1596, i64 1, !dbg !133 + %2043 = fmul <2 x float> %2042, %2040, !dbg !133 + %2044 = fptrunc <2 x float> %2043 to <2 x bfloat>, !dbg !134 + %2045 = select <2 x i1> %1239, <2 x bfloat> %2044, <2 x bfloat> zeroinitializer, !dbg !135 + %2046 = insertelement <2 x float> poison, float %1930, i64 0, !dbg !118 + %2047 = insertelement <2 x float> %2046, float %1931, i64 1, !dbg !118 + %2048 = fsub <2 x float> %2047, %532, !dbg !118 + %2049 = insertelement <2 x float> poison, float %.0.i1599, i64 0, !dbg !133 + %2050 = insertelement <2 x float> %2049, float %.0.i1602, i64 1, !dbg !133 + %2051 = fmul <2 x float> %2050, %2048, !dbg !133 + %2052 = fptrunc <2 x float> %2051 to <2 x bfloat>, !dbg !134 + %2053 = select <2 x i1> %1266, <2 x bfloat> %2052, <2 x bfloat> zeroinitializer, !dbg !135 + %2054 = insertelement <2 x float> poison, float %1932, i64 0, !dbg !118 + %2055 = insertelement <2 x float> %2054, float %1933, i64 1, !dbg !118 + %2056 = fsub <2 x float> %2055, %522, !dbg !118 + %2057 = insertelement <2 x float> poison, float %.0.i1605, i64 0, !dbg !133 + %2058 = insertelement <2 x float> %2057, float %.0.i1608, i64 1, !dbg !133 + %2059 = fmul <2 x float> %2058, %2056, !dbg !133 + %2060 = fptrunc <2 x float> %2059 to <2 x bfloat>, !dbg !134 + %2061 = select <2 x i1> %1267, <2 x bfloat> %2060, <2 x bfloat> zeroinitializer, !dbg !135 + %2062 = bitcast <2 x bfloat> %1941 to i32, !dbg !136 + %2063 = bitcast <2 x bfloat> %1949 to i32, !dbg !136 + %2064 = bitcast <2 x bfloat> %1957 to i32, !dbg !136 + %2065 = bitcast <2 x bfloat> %1965 to i32, !dbg !136 + %2066 = bitcast <2 x bfloat> %1973 to i32, !dbg !136 + %2067 = bitcast <2 x bfloat> %1981 to i32, !dbg !136 + %2068 = bitcast <2 x bfloat> %1989 to i32, !dbg !136 + %2069 = bitcast <2 x bfloat> %1997 to i32, !dbg !136 + %2070 = bitcast <2 x bfloat> %2005 to i32, !dbg !136 + %2071 = bitcast <2 x bfloat> %2013 to i32, !dbg !136 + %2072 = bitcast <2 x bfloat> %2021 to i32, !dbg !136 + %2073 = bitcast <2 x bfloat> %2029 to i32, !dbg !136 + %2074 = bitcast <2 x bfloat> %2037 to i32, !dbg !136 + %2075 = bitcast <2 x bfloat> %2045 to i32, !dbg !136 + %2076 = bitcast <2 x bfloat> %2053 to i32, !dbg !136 + %2077 = bitcast <2 x bfloat> %2061 to i32, !dbg !136 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !136 + %2078 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %537, float %538, float %539, float %540, float %541, float %542, float %543, float %544, float %545, float %546, float %547, float %548, float %549, float %550, float %551, float %552, float %553, float %554, float %555, float %556, float %557, float %558, float %559, float %560, float %561, float %562, float %563, float %564, float %565, float %566, float %567, float %568, float %569, float %570, float %571, float %572, float %573, float %574, float %575, float %576, float %577, float %578, float %579, float %580, float %581, float %582, float %583, float %584, float %585, float %586, float %587, float %588, float %589, float %590, float %591, float %592, float %593, float %594, float %595, float %596, float %597, float %598, float %599, float %600, i32 %2062, i32 %2063, i32 %2064, i32 %2065, i64 %637, i1 true) #3, !dbg !136 + %2079 = add i32 %633, 2048, !dbg !136 + %2080 = lshr exact i32 %2079, 4, !dbg !136 + %2081 = and i32 %2080, 16383, !dbg !136 + %2082 = zext nneg i32 %2081 to i64, !dbg !136 + %2083 = or disjoint i64 %2082, 4611686293338849280, !dbg !136 + %2084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 0, !dbg !136 + %2085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 1, !dbg !136 + %2086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 2, !dbg !136 + %2087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 3, !dbg !136 + %2088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 4, !dbg !136 + %2089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 5, !dbg !136 + %2090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 6, !dbg !136 + %2091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 7, !dbg !136 + %2092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 8, !dbg !136 + %2093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 9, !dbg !136 + %2094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 10, !dbg !136 + %2095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 11, !dbg !136 + %2096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 12, !dbg !136 + %2097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 13, !dbg !136 + %2098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 14, !dbg !136 + %2099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 15, !dbg !136 + %2100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 16, !dbg !136 + %2101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 17, !dbg !136 + %2102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 18, !dbg !136 + %2103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 19, !dbg !136 + %2104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 20, !dbg !136 + %2105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 21, !dbg !136 + %2106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 22, !dbg !136 + %2107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 23, !dbg !136 + %2108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 24, !dbg !136 + %2109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 25, !dbg !136 + %2110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 26, !dbg !136 + %2111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 27, !dbg !136 + %2112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 28, !dbg !136 + %2113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 29, !dbg !136 + %2114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 30, !dbg !136 + %2115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 31, !dbg !136 + %2116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 32, !dbg !136 + %2117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 33, !dbg !136 + %2118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 34, !dbg !136 + %2119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 35, !dbg !136 + %2120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 36, !dbg !136 + %2121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 37, !dbg !136 + %2122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 38, !dbg !136 + %2123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 39, !dbg !136 + %2124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 40, !dbg !136 + %2125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 41, !dbg !136 + %2126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 42, !dbg !136 + %2127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 43, !dbg !136 + %2128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 44, !dbg !136 + %2129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 45, !dbg !136 + %2130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 46, !dbg !136 + %2131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 47, !dbg !136 + %2132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 48, !dbg !136 + %2133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 49, !dbg !136 + %2134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 50, !dbg !136 + %2135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 51, !dbg !136 + %2136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 52, !dbg !136 + %2137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 53, !dbg !136 + %2138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 54, !dbg !136 + %2139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 55, !dbg !136 + %2140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 56, !dbg !136 + %2141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 57, !dbg !136 + %2142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 58, !dbg !136 + %2143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 59, !dbg !136 + %2144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 60, !dbg !136 + %2145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 61, !dbg !136 + %2146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 62, !dbg !136 + %2147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2078, 63, !dbg !136 + %2148 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2084, float %2085, float %2086, float %2087, float %2088, float %2089, float %2090, float %2091, float %2092, float %2093, float %2094, float %2095, float %2096, float %2097, float %2098, float %2099, float %2100, float %2101, float %2102, float %2103, float %2104, float %2105, float %2106, float %2107, float %2108, float %2109, float %2110, float %2111, float %2112, float %2113, float %2114, float %2115, float %2116, float %2117, float %2118, float %2119, float %2120, float %2121, float %2122, float %2123, float %2124, float %2125, float %2126, float %2127, float %2128, float %2129, float %2130, float %2131, float %2132, float %2133, float %2134, float %2135, float %2136, float %2137, float %2138, float %2139, float %2140, float %2141, float %2142, float %2143, float %2144, float %2145, float %2146, float %2147, i32 %2066, i32 %2067, i32 %2068, i32 %2069, i64 %2083, i1 true) #3, !dbg !136 + %2149 = add i32 %633, 4096, !dbg !136 + %2150 = lshr exact i32 %2149, 4, !dbg !136 + %2151 = and i32 %2150, 16383, !dbg !136 + %2152 = zext nneg i32 %2151 to i64, !dbg !136 + %2153 = or disjoint i64 %2152, 4611686293338849280, !dbg !136 + %2154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 0, !dbg !136 + %2155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 1, !dbg !136 + %2156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 2, !dbg !136 + %2157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 3, !dbg !136 + %2158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 4, !dbg !136 + %2159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 5, !dbg !136 + %2160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 6, !dbg !136 + %2161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 7, !dbg !136 + %2162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 8, !dbg !136 + %2163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 9, !dbg !136 + %2164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 10, !dbg !136 + %2165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 11, !dbg !136 + %2166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 12, !dbg !136 + %2167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 13, !dbg !136 + %2168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 14, !dbg !136 + %2169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 15, !dbg !136 + %2170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 16, !dbg !136 + %2171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 17, !dbg !136 + %2172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 18, !dbg !136 + %2173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 19, !dbg !136 + %2174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 20, !dbg !136 + %2175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 21, !dbg !136 + %2176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 22, !dbg !136 + %2177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 23, !dbg !136 + %2178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 24, !dbg !136 + %2179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 25, !dbg !136 + %2180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 26, !dbg !136 + %2181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 27, !dbg !136 + %2182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 28, !dbg !136 + %2183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 29, !dbg !136 + %2184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 30, !dbg !136 + %2185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 31, !dbg !136 + %2186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 32, !dbg !136 + %2187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 33, !dbg !136 + %2188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 34, !dbg !136 + %2189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 35, !dbg !136 + %2190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 36, !dbg !136 + %2191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 37, !dbg !136 + %2192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 38, !dbg !136 + %2193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 39, !dbg !136 + %2194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 40, !dbg !136 + %2195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 41, !dbg !136 + %2196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 42, !dbg !136 + %2197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 43, !dbg !136 + %2198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 44, !dbg !136 + %2199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 45, !dbg !136 + %2200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 46, !dbg !136 + %2201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 47, !dbg !136 + %2202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 48, !dbg !136 + %2203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 49, !dbg !136 + %2204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 50, !dbg !136 + %2205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 51, !dbg !136 + %2206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 52, !dbg !136 + %2207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 53, !dbg !136 + %2208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 54, !dbg !136 + %2209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 55, !dbg !136 + %2210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 56, !dbg !136 + %2211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 57, !dbg !136 + %2212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 58, !dbg !136 + %2213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 59, !dbg !136 + %2214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 60, !dbg !136 + %2215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 61, !dbg !136 + %2216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 62, !dbg !136 + %2217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2148, 63, !dbg !136 + %2218 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2154, float %2155, float %2156, float %2157, float %2158, float %2159, float %2160, float %2161, float %2162, float %2163, float %2164, float %2165, float %2166, float %2167, float %2168, float %2169, float %2170, float %2171, float %2172, float %2173, float %2174, float %2175, float %2176, float %2177, float %2178, float %2179, float %2180, float %2181, float %2182, float %2183, float %2184, float %2185, float %2186, float %2187, float %2188, float %2189, float %2190, float %2191, float %2192, float %2193, float %2194, float %2195, float %2196, float %2197, float %2198, float %2199, float %2200, float %2201, float %2202, float %2203, float %2204, float %2205, float %2206, float %2207, float %2208, float %2209, float %2210, float %2211, float %2212, float %2213, float %2214, float %2215, float %2216, float %2217, i32 %2070, i32 %2071, i32 %2072, i32 %2073, i64 %2153, i1 true) #3, !dbg !136 + %2219 = add i32 %633, 6144, !dbg !136 + %2220 = lshr exact i32 %2219, 4, !dbg !136 + %2221 = and i32 %2220, 16383, !dbg !136 + %2222 = zext nneg i32 %2221 to i64, !dbg !136 + %2223 = or disjoint i64 %2222, 4611686293338849280, !dbg !136 + %2224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 0, !dbg !136 + %2225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 1, !dbg !136 + %2226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 2, !dbg !136 + %2227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 3, !dbg !136 + %2228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 4, !dbg !136 + %2229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 5, !dbg !136 + %2230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 6, !dbg !136 + %2231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 7, !dbg !136 + %2232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 8, !dbg !136 + %2233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 9, !dbg !136 + %2234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 10, !dbg !136 + %2235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 11, !dbg !136 + %2236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 12, !dbg !136 + %2237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 13, !dbg !136 + %2238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 14, !dbg !136 + %2239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 15, !dbg !136 + %2240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 16, !dbg !136 + %2241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 17, !dbg !136 + %2242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 18, !dbg !136 + %2243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 19, !dbg !136 + %2244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 20, !dbg !136 + %2245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 21, !dbg !136 + %2246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 22, !dbg !136 + %2247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 23, !dbg !136 + %2248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 24, !dbg !136 + %2249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 25, !dbg !136 + %2250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 26, !dbg !136 + %2251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 27, !dbg !136 + %2252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 28, !dbg !136 + %2253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 29, !dbg !136 + %2254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 30, !dbg !136 + %2255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 31, !dbg !136 + %2256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 32, !dbg !136 + %2257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 33, !dbg !136 + %2258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 34, !dbg !136 + %2259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 35, !dbg !136 + %2260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 36, !dbg !136 + %2261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 37, !dbg !136 + %2262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 38, !dbg !136 + %2263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 39, !dbg !136 + %2264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 40, !dbg !136 + %2265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 41, !dbg !136 + %2266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 42, !dbg !136 + %2267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 43, !dbg !136 + %2268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 44, !dbg !136 + %2269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 45, !dbg !136 + %2270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 46, !dbg !136 + %2271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 47, !dbg !136 + %2272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 48, !dbg !136 + %2273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 49, !dbg !136 + %2274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 50, !dbg !136 + %2275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 51, !dbg !136 + %2276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 52, !dbg !136 + %2277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 53, !dbg !136 + %2278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 54, !dbg !136 + %2279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 55, !dbg !136 + %2280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 56, !dbg !136 + %2281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 57, !dbg !136 + %2282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 58, !dbg !136 + %2283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 59, !dbg !136 + %2284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 60, !dbg !136 + %2285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 61, !dbg !136 + %2286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 62, !dbg !136 + %2287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2218, 63, !dbg !136 + %2288 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %2224, float %2225, float %2226, float %2227, float %2228, float %2229, float %2230, float %2231, float %2232, float %2233, float %2234, float %2235, float %2236, float %2237, float %2238, float %2239, float %2240, float %2241, float %2242, float %2243, float %2244, float %2245, float %2246, float %2247, float %2248, float %2249, float %2250, float %2251, float %2252, float %2253, float %2254, float %2255, float %2256, float %2257, float %2258, float %2259, float %2260, float %2261, float %2262, float %2263, float %2264, float %2265, float %2266, float %2267, float %2268, float %2269, float %2270, float %2271, float %2272, float %2273, float %2274, float %2275, float %2276, float %2277, float %2278, float %2279, float %2280, float %2281, float %2282, float %2283, float %2284, float %2285, float %2286, float %2287, i32 %2074, i32 %2075, i32 %2076, i32 %2077, i64 %2223, i1 true) #3, !dbg !136 + %2289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 0, !dbg !136 + %2290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 1, !dbg !136 + %2291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 2, !dbg !136 + %2292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 3, !dbg !136 + %2293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 4, !dbg !136 + %2294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 5, !dbg !136 + %2295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 6, !dbg !136 + %2296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 7, !dbg !136 + %2297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 8, !dbg !136 + %2298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 9, !dbg !136 + %2299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 10, !dbg !136 + %2300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 11, !dbg !136 + %2301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 12, !dbg !136 + %2302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 13, !dbg !136 + %2303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 14, !dbg !136 + %2304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 15, !dbg !136 + %2305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 16, !dbg !136 + %2306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 17, !dbg !136 + %2307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 18, !dbg !136 + %2308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 19, !dbg !136 + %2309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 20, !dbg !136 + %2310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 21, !dbg !136 + %2311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 22, !dbg !136 + %2312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 23, !dbg !136 + %2313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 24, !dbg !136 + %2314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 25, !dbg !136 + %2315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 26, !dbg !136 + %2316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 27, !dbg !136 + %2317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 28, !dbg !136 + %2318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 29, !dbg !136 + %2319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 30, !dbg !136 + %2320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 31, !dbg !136 + %2321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 32, !dbg !136 + %2322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 33, !dbg !136 + %2323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 34, !dbg !136 + %2324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 35, !dbg !136 + %2325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 36, !dbg !136 + %2326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 37, !dbg !136 + %2327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 38, !dbg !136 + %2328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 39, !dbg !136 + %2329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 40, !dbg !136 + %2330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 41, !dbg !136 + %2331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 42, !dbg !136 + %2332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 43, !dbg !136 + %2333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 44, !dbg !136 + %2334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 45, !dbg !136 + %2335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 46, !dbg !136 + %2336 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 47, !dbg !136 + %2337 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 48, !dbg !136 + %2338 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 49, !dbg !136 + %2339 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 50, !dbg !136 + %2340 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 51, !dbg !136 + %2341 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 52, !dbg !136 + %2342 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 53, !dbg !136 + %2343 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 54, !dbg !136 + %2344 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 55, !dbg !136 + %2345 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 56, !dbg !136 + %2346 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 57, !dbg !136 + %2347 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 58, !dbg !136 + %2348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 59, !dbg !136 + %2349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 60, !dbg !136 + %2350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 61, !dbg !136 + %2351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 62, !dbg !136 + %2352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2288, 63, !dbg !136 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !136 + %2353 = insertelement <2 x i32> poison, i32 %534, i64 0, !dbg !106 + %2354 = shufflevector <2 x i32> %2353, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !106 + %2355 = add <2 x i32> %2354, %609, !dbg !106 + %2356 = add <2 x i32> %2354, %608, !dbg !106 + %2357 = add <2 x i32> %2354, %607, !dbg !106 + %2358 = add <2 x i32> %2354, %606, !dbg !106 + %2359 = add <2 x i32> %2354, %605, !dbg !106 + %2360 = add <2 x i32> %2354, %604, !dbg !106 + %2361 = add <2 x i32> %2354, %603, !dbg !106 + %2362 = add <2 x i32> %2354, %602, !dbg !106 + %2363 = add nuw nsw i32 %601, 1, !dbg !101 + %2364 = lshr i32 %2363, 1, !dbg !137 + %2365 = zext nneg i32 %2364 to i64, !dbg !138 + %2366 = getelementptr i32, ptr addrspace(1) %369, i64 %2365, !dbg !138 + %2367 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !139 + %2368 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2366, i64 %2367, i1 %611) #3, !dbg !139 + %2369 = add nuw nsw i32 %2364, 1, !dbg !140 + %2370 = icmp slt i32 %2369, %373, !dbg !141 + %2371 = getelementptr i8, ptr addrspace(1) %2366, i64 4, !dbg !142 + %2372 = and i1 %611, %2370, !dbg !101 + %2373 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !143 + %2374 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %2371, i64 %2373, i1 %2372) #3, !dbg !143 + %2375 = and i32 %601, 1, !dbg !144 + %2376 = sub i32 %2374, %2368, !dbg !145 + %2377 = shl i32 %2376, 7, !dbg !146 + %2378 = add i32 %2377, -64, !dbg !147 + %2379 = xor i32 %2375, 1, !dbg !148 + %2380 = mul nuw nsw i32 %2378, %2379, !dbg !148 + %2381 = shl nuw nsw i32 %2375, 6, !dbg !149 + %2382 = add i32 %2380, %2381, !dbg !150 + %2383 = shl i32 %2382, 10, !dbg !151 + %2384 = sext i32 %2383 to i64, !dbg !104 + %2385 = getelementptr bfloat, ptr addrspace(1) %.pn9091615, i64 %2384, !dbg !104 + %2386 = getelementptr bfloat, ptr addrspace(1) %.pn8931616, i64 %2384, !dbg !104 + %2387 = getelementptr bfloat, ptr addrspace(1) %.pn8771617, i64 %2384, !dbg !104 + %2388 = getelementptr bfloat, ptr addrspace(1) %.pn8611618, i64 %2384, !dbg !104 + %2389 = getelementptr bfloat, ptr addrspace(1) %.pn9811623, i64 %2384, !dbg !105 + %2390 = getelementptr bfloat, ptr addrspace(1) %.pn9651624, i64 %2384, !dbg !105 + %2391 = getelementptr bfloat, ptr addrspace(1) %.pn9491625, i64 %2384, !dbg !105 + %2392 = getelementptr bfloat, ptr addrspace(1) %.pn9331626, i64 %2384, !dbg !105 + %2393 = add i32 %2382, %.pn9171619, !dbg !106 + %2394 = add i32 %2382, %.pn9151620, !dbg !106 + %2395 = add i32 %2382, %.pn9131621, !dbg !106 + %2396 = add i32 %2382, %.pn9111622, !dbg !106 + %2397 = add i32 %536, 1, !dbg !101 + %2398 = icmp sgt i32 %2397, 2, !dbg !101 + %2399 = select i1 %2398, i32 0, i32 %2397, !dbg !101 + %2400 = icmp slt i32 %2393, %18, !dbg !102 + %2401 = icmp slt i32 %2394, %18, !dbg !102 + %2402 = icmp slt i32 %2395, %18, !dbg !102 + %2403 = icmp slt i32 %2396, %18, !dbg !102 + %2404 = shl i32 %2399, 13, !dbg !103 + %2405 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2404, !dbg !103 + %2406 = and i1 %610, %2400, !dbg !101 + %2407 = and i1 %610, %2401, !dbg !101 + %2408 = and i1 %610, %2402, !dbg !101 + %2409 = and i1 %610, %2403, !dbg !101 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !103 + %2410 = getelementptr inbounds nuw i8, ptr addrspace(3) %2405, i32 %429, !dbg !103 + %2411 = select i1 %2406, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2410, ptr addrspace(1) %2385, i32 %2411) #3, !dbg !103 + %2412 = getelementptr inbounds nuw i8, ptr addrspace(3) %2405, i32 %432, !dbg !103 + %2413 = select i1 %2407, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2412, ptr addrspace(1) %2386, i32 %2413) #3, !dbg !103 + %2414 = getelementptr inbounds nuw i8, ptr addrspace(3) %2405, i32 %435, !dbg !103 + %2415 = select i1 %2408, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2414, ptr addrspace(1) %2387, i32 %2415) #3, !dbg !103 + %2416 = getelementptr inbounds nuw i8, ptr addrspace(3) %2405, i32 %438, !dbg !103 + %2417 = select i1 %2409, i32 16, i32 0, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2416, ptr addrspace(1) %2388, i32 %2417) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + %2418 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2404, !dbg !103 + %2419 = getelementptr inbounds nuw i8, ptr addrspace(3) %2418, i32 %429, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %2419, ptr addrspace(1) %2389, i32 %2411) #3, !dbg !103 + %2420 = getelementptr inbounds nuw i8, ptr addrspace(3) %2418, i32 %432, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2420, ptr addrspace(1) %2390, i32 %2413) #3, !dbg !103 + %2421 = getelementptr inbounds nuw i8, ptr addrspace(3) %2418, i32 %435, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2421, ptr addrspace(1) %2391, i32 %2415) #3, !dbg !103 + %2422 = getelementptr inbounds nuw i8, ptr addrspace(3) %2418, i32 %438, !dbg !103 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %2422, ptr addrspace(1) %2392, i32 %2417) #3, !dbg !103 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !103 + %exitcond.not = icmp eq i32 %2363, %smax, !dbg !101 + br i1 %exitcond.not, label %._crit_edge, label %533, !dbg !101 + +._crit_edge: ; preds = %__nv_exp2f.exit1609, %61 + %2423 = phi float [ 0.000000e+00, %61 ], [ %2289, %__nv_exp2f.exit1609 ] + %2424 = phi float [ 0.000000e+00, %61 ], [ %2290, %__nv_exp2f.exit1609 ] + %2425 = phi float [ 0.000000e+00, %61 ], [ %2291, %__nv_exp2f.exit1609 ] + %2426 = phi float [ 0.000000e+00, %61 ], [ %2292, %__nv_exp2f.exit1609 ] + %2427 = phi float [ 0.000000e+00, %61 ], [ %2293, %__nv_exp2f.exit1609 ] + %2428 = phi float [ 0.000000e+00, %61 ], [ %2294, %__nv_exp2f.exit1609 ] + %2429 = phi float [ 0.000000e+00, %61 ], [ %2295, %__nv_exp2f.exit1609 ] + %2430 = phi float [ 0.000000e+00, %61 ], [ %2296, %__nv_exp2f.exit1609 ] + %2431 = phi float [ 0.000000e+00, %61 ], [ %2297, %__nv_exp2f.exit1609 ] + %2432 = phi float [ 0.000000e+00, %61 ], [ %2298, %__nv_exp2f.exit1609 ] + %2433 = phi float [ 0.000000e+00, %61 ], [ %2299, %__nv_exp2f.exit1609 ] + %2434 = phi float [ 0.000000e+00, %61 ], [ %2300, %__nv_exp2f.exit1609 ] + %2435 = phi float [ 0.000000e+00, %61 ], [ %2301, %__nv_exp2f.exit1609 ] + %2436 = phi float [ 0.000000e+00, %61 ], [ %2302, %__nv_exp2f.exit1609 ] + %2437 = phi float [ 0.000000e+00, %61 ], [ %2303, %__nv_exp2f.exit1609 ] + %2438 = phi float [ 0.000000e+00, %61 ], [ %2304, %__nv_exp2f.exit1609 ] + %2439 = phi float [ 0.000000e+00, %61 ], [ %2305, %__nv_exp2f.exit1609 ] + %2440 = phi float [ 0.000000e+00, %61 ], [ %2306, %__nv_exp2f.exit1609 ] + %2441 = phi float [ 0.000000e+00, %61 ], [ %2307, %__nv_exp2f.exit1609 ] + %2442 = phi float [ 0.000000e+00, %61 ], [ %2308, %__nv_exp2f.exit1609 ] + %2443 = phi float [ 0.000000e+00, %61 ], [ %2309, %__nv_exp2f.exit1609 ] + %2444 = phi float [ 0.000000e+00, %61 ], [ %2310, %__nv_exp2f.exit1609 ] + %2445 = phi float [ 0.000000e+00, %61 ], [ %2311, %__nv_exp2f.exit1609 ] + %2446 = phi float [ 0.000000e+00, %61 ], [ %2312, %__nv_exp2f.exit1609 ] + %2447 = phi float [ 0.000000e+00, %61 ], [ %2313, %__nv_exp2f.exit1609 ] + %2448 = phi float [ 0.000000e+00, %61 ], [ %2314, %__nv_exp2f.exit1609 ] + %2449 = phi float [ 0.000000e+00, %61 ], [ %2315, %__nv_exp2f.exit1609 ] + %2450 = phi float [ 0.000000e+00, %61 ], [ %2316, %__nv_exp2f.exit1609 ] + %2451 = phi float [ 0.000000e+00, %61 ], [ %2317, %__nv_exp2f.exit1609 ] + %2452 = phi float [ 0.000000e+00, %61 ], [ %2318, %__nv_exp2f.exit1609 ] + %2453 = phi float [ 0.000000e+00, %61 ], [ %2319, %__nv_exp2f.exit1609 ] + %2454 = phi float [ 0.000000e+00, %61 ], [ %2320, %__nv_exp2f.exit1609 ] + %2455 = phi float [ 0.000000e+00, %61 ], [ %2321, %__nv_exp2f.exit1609 ] + %2456 = phi float [ 0.000000e+00, %61 ], [ %2322, %__nv_exp2f.exit1609 ] + %2457 = phi float [ 0.000000e+00, %61 ], [ %2323, %__nv_exp2f.exit1609 ] + %2458 = phi float [ 0.000000e+00, %61 ], [ %2324, %__nv_exp2f.exit1609 ] + %2459 = phi float [ 0.000000e+00, %61 ], [ %2325, %__nv_exp2f.exit1609 ] + %2460 = phi float [ 0.000000e+00, %61 ], [ %2326, %__nv_exp2f.exit1609 ] + %2461 = phi float [ 0.000000e+00, %61 ], [ %2327, %__nv_exp2f.exit1609 ] + %2462 = phi float [ 0.000000e+00, %61 ], [ %2328, %__nv_exp2f.exit1609 ] + %2463 = phi float [ 0.000000e+00, %61 ], [ %2329, %__nv_exp2f.exit1609 ] + %2464 = phi float [ 0.000000e+00, %61 ], [ %2330, %__nv_exp2f.exit1609 ] + %2465 = phi float [ 0.000000e+00, %61 ], [ %2331, %__nv_exp2f.exit1609 ] + %2466 = phi float [ 0.000000e+00, %61 ], [ %2332, %__nv_exp2f.exit1609 ] + %2467 = phi float [ 0.000000e+00, %61 ], [ %2333, %__nv_exp2f.exit1609 ] + %2468 = phi float [ 0.000000e+00, %61 ], [ %2334, %__nv_exp2f.exit1609 ] + %2469 = phi float [ 0.000000e+00, %61 ], [ %2335, %__nv_exp2f.exit1609 ] + %2470 = phi float [ 0.000000e+00, %61 ], [ %2336, %__nv_exp2f.exit1609 ] + %2471 = phi float [ 0.000000e+00, %61 ], [ %2337, %__nv_exp2f.exit1609 ] + %2472 = phi float [ 0.000000e+00, %61 ], [ %2338, %__nv_exp2f.exit1609 ] + %2473 = phi float [ 0.000000e+00, %61 ], [ %2339, %__nv_exp2f.exit1609 ] + %2474 = phi float [ 0.000000e+00, %61 ], [ %2340, %__nv_exp2f.exit1609 ] + %2475 = phi float [ 0.000000e+00, %61 ], [ %2341, %__nv_exp2f.exit1609 ] + %2476 = phi float [ 0.000000e+00, %61 ], [ %2342, %__nv_exp2f.exit1609 ] + %2477 = phi float [ 0.000000e+00, %61 ], [ %2343, %__nv_exp2f.exit1609 ] + %2478 = phi float [ 0.000000e+00, %61 ], [ %2344, %__nv_exp2f.exit1609 ] + %2479 = phi float [ 0.000000e+00, %61 ], [ %2345, %__nv_exp2f.exit1609 ] + %2480 = phi float [ 0.000000e+00, %61 ], [ %2346, %__nv_exp2f.exit1609 ] + %2481 = phi float [ 0.000000e+00, %61 ], [ %2347, %__nv_exp2f.exit1609 ] + %2482 = phi float [ 0.000000e+00, %61 ], [ %2348, %__nv_exp2f.exit1609 ] + %2483 = phi float [ 0.000000e+00, %61 ], [ %2349, %__nv_exp2f.exit1609 ] + %2484 = phi float [ 0.000000e+00, %61 ], [ %2350, %__nv_exp2f.exit1609 ] + %2485 = phi float [ 0.000000e+00, %61 ], [ %2351, %__nv_exp2f.exit1609 ] + %2486 = phi float [ 0.000000e+00, %61 ], [ %2352, %__nv_exp2f.exit1609 ] + %2487 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %2423, float %2424, float %2425, float %2426, float %2427, float %2428, float %2429, float %2430, float %2431, float %2432, float %2433, float %2434, float %2435, float %2436, float %2437, float %2438, float %2439, float %2440, float %2441, float %2442, float %2443, float %2444, float %2445, float %2446, float %2447, float %2448, float %2449, float %2450, float %2451, float %2452, float %2453, float %2454, float %2455, float %2456, float %2457, float %2458, float %2459, float %2460, float %2461, float %2462, float %2463, float %2464, float %2465, float %2466, float %2467, float %2468, float %2469, float %2470, float %2471, float %2472, float %2473, float %2474, float %2475, float %2476, float %2477, float %2478, float %2479, float %2480, float %2481, float %2482, float %2483, float %2484, float %2485, float %2486) #3, !dbg !101 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !101 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !101 + %2488 = getelementptr i32, ptr addrspace(1) %13, i64 %368, !dbg !152 + %2489 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2488) #3, !dbg !153 + %2490 = shl i32 %2489, 7, !dbg !154 + %2491 = getelementptr i32, ptr addrspace(1) %12, i64 %368, !dbg !155 + %2492 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l"(ptr addrspace(1) %2491) #3, !dbg !156 + %2493 = or disjoint i32 %2490, %47, !dbg !157 + %2494 = or disjoint i32 %2490, %48, !dbg !157 + %2495 = or disjoint i32 %2490, %49, !dbg !157 + %2496 = or disjoint i32 %2490, %50, !dbg !157 + %2497 = shl i32 %2493, 10, !dbg !158 + %2498 = shl i32 %2494, 10, !dbg !158 + %2499 = shl i32 %2495, 10, !dbg !158 + %2500 = shl i32 %2496, 10, !dbg !158 + %2501 = sext i32 %2497 to i64, !dbg !160 + %2502 = getelementptr bfloat, ptr addrspace(1) %41, i64 %2501, !dbg !160 + %2503 = sext i32 %2498 to i64, !dbg !160 + %2504 = getelementptr bfloat, ptr addrspace(1) %41, i64 %2503, !dbg !160 + %2505 = sext i32 %2499 to i64, !dbg !160 + %2506 = getelementptr bfloat, ptr addrspace(1) %41, i64 %2505, !dbg !160 + %2507 = sext i32 %2500 to i64, !dbg !160 + %2508 = getelementptr bfloat, ptr addrspace(1) %41, i64 %2507, !dbg !160 + %2509 = getelementptr bfloat, ptr addrspace(1) %2502, i64 %123, !dbg !161 + %2510 = getelementptr bfloat, ptr addrspace(1) %2504, i64 %123, !dbg !161 + %2511 = getelementptr bfloat, ptr addrspace(1) %2506, i64 %123, !dbg !161 + %2512 = getelementptr bfloat, ptr addrspace(1) %2508, i64 %123, !dbg !161 + %2513 = getelementptr bfloat, ptr addrspace(1) %42, i64 %2501, !dbg !162 + %2514 = getelementptr bfloat, ptr addrspace(1) %42, i64 %2503, !dbg !162 + %2515 = getelementptr bfloat, ptr addrspace(1) %42, i64 %2505, !dbg !162 + %2516 = getelementptr bfloat, ptr addrspace(1) %42, i64 %2507, !dbg !162 + %2517 = getelementptr bfloat, ptr addrspace(1) %2513, i64 %123, !dbg !163 + %2518 = getelementptr bfloat, ptr addrspace(1) %2514, i64 %123, !dbg !163 + %2519 = getelementptr bfloat, ptr addrspace(1) %2515, i64 %123, !dbg !163 + %2520 = getelementptr bfloat, ptr addrspace(1) %2516, i64 %123, !dbg !163 + %2521 = shl i32 %2492, 1, !dbg !164 + %2522 = tail call i32 @llvm.smin.i32(i32 %2521, i32 %417), !dbg !165 + %2523 = icmp sgt i32 %2521, 0, !dbg !166 + %2524 = icmp slt i32 %2493, %18, !dbg !167 + %2525 = icmp slt i32 %2494, %18, !dbg !167 + %2526 = icmp slt i32 %2495, %18, !dbg !167 + %2527 = icmp slt i32 %2496, %18, !dbg !167 + %2528 = and i1 %2523, %2524, !dbg !166 + %2529 = and i1 %2523, %2525, !dbg !166 + %2530 = and i1 %2523, %2526, !dbg !166 + %2531 = and i1 %2523, %2527, !dbg !166 + %2532 = select i1 %2528, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %430, ptr addrspace(1) %2509, i32 %2532) #3, !dbg !168 + %2533 = select i1 %2529, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %433, ptr addrspace(1) %2510, i32 %2533) #3, !dbg !168 + %2534 = select i1 %2530, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %436, ptr addrspace(1) %2511, i32 %2534) #3, !dbg !168 + %2535 = select i1 %2531, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %439, ptr addrspace(1) %2512, i32 %2535) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %441, ptr addrspace(1) %2517, i32 %2532) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %442, ptr addrspace(1) %2518, i32 %2533) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %443, ptr addrspace(1) %2519, i32 %2534) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %444, ptr addrspace(1) %2520, i32 %2535) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + %2536 = icmp sgt i32 %2522, 1, !dbg !166 + %2537 = getelementptr i8, ptr addrspace(1) %2509, i64 131072, !dbg !169 + %2538 = getelementptr i8, ptr addrspace(1) %2510, i64 131072, !dbg !169 + %2539 = getelementptr i8, ptr addrspace(1) %2511, i64 131072, !dbg !169 + %2540 = getelementptr i8, ptr addrspace(1) %2512, i64 131072, !dbg !169 + %2541 = getelementptr i8, ptr addrspace(1) %2517, i64 131072, !dbg !170 + %2542 = getelementptr i8, ptr addrspace(1) %2518, i64 131072, !dbg !170 + %2543 = getelementptr i8, ptr addrspace(1) %2519, i64 131072, !dbg !170 + %2544 = getelementptr i8, ptr addrspace(1) %2520, i64 131072, !dbg !170 + %2545 = or disjoint i32 %2493, 64, !dbg !171 + %2546 = or disjoint i32 %2494, 64, !dbg !171 + %2547 = or disjoint i32 %2495, 64, !dbg !171 + %2548 = or disjoint i32 %2496, 64, !dbg !171 + %2549 = icmp slt i32 %2545, %18, !dbg !167 + %2550 = icmp slt i32 %2546, %18, !dbg !167 + %2551 = icmp slt i32 %2547, %18, !dbg !167 + %2552 = icmp slt i32 %2548, %18, !dbg !167 + %2553 = and i1 %2536, %2549, !dbg !166 + %2554 = and i1 %2536, %2550, !dbg !166 + %2555 = and i1 %2536, %2551, !dbg !166 + %2556 = and i1 %2536, %2552, !dbg !166 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !168 + %2557 = select i1 %2553, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %466, ptr addrspace(1) %2537, i32 %2557) #3, !dbg !168 + %2558 = select i1 %2554, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %468, ptr addrspace(1) %2538, i32 %2558) #3, !dbg !168 + %2559 = select i1 %2555, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %470, ptr addrspace(1) %2539, i32 %2559) #3, !dbg !168 + %2560 = select i1 %2556, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %472, ptr addrspace(1) %2540, i32 %2560) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %474, ptr addrspace(1) %2541, i32 %2557) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %475, ptr addrspace(1) %2542, i32 %2558) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %476, ptr addrspace(1) %2543, i32 %2559) #3, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %477, ptr addrspace(1) %2544, i32 %2560) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !172 + %2561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 0, !dbg !166 + %2562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 1, !dbg !166 + %2563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 2, !dbg !166 + %2564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 3, !dbg !166 + %2565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 4, !dbg !166 + %2566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 5, !dbg !166 + %2567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 6, !dbg !166 + %2568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 7, !dbg !166 + %2569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 8, !dbg !166 + %2570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 9, !dbg !166 + %2571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 10, !dbg !166 + %2572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 11, !dbg !166 + %2573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 12, !dbg !166 + %2574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 13, !dbg !166 + %2575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 14, !dbg !166 + %2576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 15, !dbg !166 + %2577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 16, !dbg !166 + %2578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 17, !dbg !166 + %2579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 18, !dbg !166 + %2580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 19, !dbg !166 + %2581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 20, !dbg !166 + %2582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 21, !dbg !166 + %2583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 22, !dbg !166 + %2584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 23, !dbg !166 + %2585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 24, !dbg !166 + %2586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 25, !dbg !166 + %2587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 26, !dbg !166 + %2588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 27, !dbg !166 + %2589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 28, !dbg !166 + %2590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 29, !dbg !166 + %2591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 30, !dbg !166 + %2592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 31, !dbg !166 + %2593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 32, !dbg !166 + %2594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 33, !dbg !166 + %2595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 34, !dbg !166 + %2596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 35, !dbg !166 + %2597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 36, !dbg !166 + %2598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 37, !dbg !166 + %2599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 38, !dbg !166 + %2600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 39, !dbg !166 + %2601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 40, !dbg !166 + %2602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 41, !dbg !166 + %2603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 42, !dbg !166 + %2604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 43, !dbg !166 + %2605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 44, !dbg !166 + %2606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 45, !dbg !166 + %2607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 46, !dbg !166 + %2608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 47, !dbg !166 + %2609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 48, !dbg !166 + %2610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 49, !dbg !166 + %2611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 50, !dbg !166 + %2612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 51, !dbg !166 + %2613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 52, !dbg !166 + %2614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 53, !dbg !166 + %2615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 54, !dbg !166 + %2616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 55, !dbg !166 + %2617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 56, !dbg !166 + %2618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 57, !dbg !166 + %2619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 58, !dbg !166 + %2620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 59, !dbg !166 + %2621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 60, !dbg !166 + %2622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 61, !dbg !166 + %2623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 62, !dbg !166 + %2624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 63, !dbg !166 + br i1 %2523, label %.lr.ph1672, label %._crit_edge1673, !dbg !166 + +.lr.ph1672: ; preds = %._crit_edge + %2625 = insertelement <16 x i32> poison, i32 %2490, i64 0, !dbg !157 + %2626 = shufflevector <16 x i32> %2625, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !157 + %2627 = shufflevector <2 x i32> %379, <2 x i32> poison, <16 x i32> , !dbg !157 + %2628 = insertelement <16 x i32> %2627, i32 %376, i64 14, !dbg !157 + %2629 = insertelement <16 x i32> %2628, i32 %375, i64 15, !dbg !157 + %2630 = shufflevector <8 x i32> %385, <8 x i32> poison, <16 x i32> , !dbg !157 + %2631 = shufflevector <16 x i32> %2630, <16 x i32> %2629, <16 x i32> , !dbg !157 + %2632 = shufflevector <4 x i32> %382, <4 x i32> poison, <16 x i32> , !dbg !157 + %2633 = shufflevector <16 x i32> %2631, <16 x i32> %2632, <16 x i32> , !dbg !157 + %2634 = or disjoint <16 x i32> %2626, %2633, !dbg !157 + %2635 = add nsw i32 %2522, -2 + %2636 = add nsw i32 %2522, -1 + %smax2263 = tail call i32 @llvm.smax.i32(i32 %2522, i32 1), !dbg !166 + %2637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 0, !dbg !166 + %2638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 1, !dbg !166 + %2639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 2, !dbg !166 + %2640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 3, !dbg !166 + %2641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 4, !dbg !166 + %2642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 5, !dbg !166 + %2643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 6, !dbg !166 + %2644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 7, !dbg !166 + %2645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 8, !dbg !166 + %2646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 9, !dbg !166 + %2647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 10, !dbg !166 + %2648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 11, !dbg !166 + %2649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 12, !dbg !166 + %2650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 13, !dbg !166 + %2651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 14, !dbg !166 + %2652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 15, !dbg !166 + %2653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 16, !dbg !166 + %2654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 17, !dbg !166 + %2655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 18, !dbg !166 + %2656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 19, !dbg !166 + %2657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 20, !dbg !166 + %2658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 21, !dbg !166 + %2659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 22, !dbg !166 + %2660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 23, !dbg !166 + %2661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 24, !dbg !166 + %2662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 25, !dbg !166 + %2663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 26, !dbg !166 + %2664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 27, !dbg !166 + %2665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 28, !dbg !166 + %2666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 29, !dbg !166 + %2667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 30, !dbg !166 + %2668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 31, !dbg !166 + %2669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 32, !dbg !166 + %2670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 33, !dbg !166 + %2671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 34, !dbg !166 + %2672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 35, !dbg !166 + %2673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 36, !dbg !166 + %2674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 37, !dbg !166 + %2675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 38, !dbg !166 + %2676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 39, !dbg !166 + %2677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 40, !dbg !166 + %2678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 41, !dbg !166 + %2679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 42, !dbg !166 + %2680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 43, !dbg !166 + %2681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 44, !dbg !166 + %2682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 45, !dbg !166 + %2683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 46, !dbg !166 + %2684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 47, !dbg !166 + %2685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 48, !dbg !166 + %2686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 49, !dbg !166 + %2687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 50, !dbg !166 + %2688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 51, !dbg !166 + %2689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 52, !dbg !166 + %2690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 53, !dbg !166 + %2691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 54, !dbg !166 + %2692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 55, !dbg !166 + %2693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 56, !dbg !166 + %2694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 57, !dbg !166 + %2695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 58, !dbg !166 + %2696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 59, !dbg !166 + %2697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 60, !dbg !166 + %2698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 61, !dbg !166 + %2699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 62, !dbg !166 + %2700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2487, 63, !dbg !166 + br label %2701, !dbg !166 + +2701: ; preds = %.lr.ph1672, %__nv_exp2f.exit1513 + %2702 = phi i32 [ 64, %.lr.ph1672 ], [ %4272, %__nv_exp2f.exit1513 ] + %2703 = phi i32 [ -1, %.lr.ph1672 ], [ %2711, %__nv_exp2f.exit1513 ] + %2704 = phi i32 [ 1, %.lr.ph1672 ], [ %4289, %__nv_exp2f.exit1513 ] + %.pn11011654 = phi ptr addrspace(1) [ %2544, %.lr.ph1672 ], [ %4282, %__nv_exp2f.exit1513 ] + %.pn11171653 = phi ptr addrspace(1) [ %2543, %.lr.ph1672 ], [ %4281, %__nv_exp2f.exit1513 ] + %.pn11331652 = phi ptr addrspace(1) [ %2542, %.lr.ph1672 ], [ %4280, %__nv_exp2f.exit1513 ] + %.pn11491651 = phi ptr addrspace(1) [ %2541, %.lr.ph1672 ], [ %4279, %__nv_exp2f.exit1513 ] + %.pn10791650 = phi i32 [ %2548, %.lr.ph1672 ], [ %4286, %__nv_exp2f.exit1513 ] + %.pn10811649 = phi i32 [ %2547, %.lr.ph1672 ], [ %4285, %__nv_exp2f.exit1513 ] + %.pn10831648 = phi i32 [ %2546, %.lr.ph1672 ], [ %4284, %__nv_exp2f.exit1513 ] + %.pn10851647 = phi i32 [ %2545, %.lr.ph1672 ], [ %4283, %__nv_exp2f.exit1513 ] + %.pn10291646 = phi ptr addrspace(1) [ %2540, %.lr.ph1672 ], [ %4278, %__nv_exp2f.exit1513 ] + %.pn10451645 = phi ptr addrspace(1) [ %2539, %.lr.ph1672 ], [ %4277, %__nv_exp2f.exit1513 ] + %.pn10611644 = phi ptr addrspace(1) [ %2538, %.lr.ph1672 ], [ %4276, %__nv_exp2f.exit1513 ] + %.pn10771643 = phi ptr addrspace(1) [ %2537, %.lr.ph1672 ], [ %4275, %__nv_exp2f.exit1513 ] + %.pn = phi float [ %2637, %.lr.ph1672 ], [ %4186, %__nv_exp2f.exit1513 ] + %.pn2530 = phi float [ %2638, %.lr.ph1672 ], [ %4187, %__nv_exp2f.exit1513 ] + %.pn2531 = phi float [ %2639, %.lr.ph1672 ], [ %4188, %__nv_exp2f.exit1513 ] + %.pn2532 = phi float [ %2640, %.lr.ph1672 ], [ %4189, %__nv_exp2f.exit1513 ] + %.pn2533 = phi float [ %2641, %.lr.ph1672 ], [ %4190, %__nv_exp2f.exit1513 ] + %.pn2534 = phi float [ %2642, %.lr.ph1672 ], [ %4191, %__nv_exp2f.exit1513 ] + %.pn2535 = phi float [ %2643, %.lr.ph1672 ], [ %4192, %__nv_exp2f.exit1513 ] + %.pn2536 = phi float [ %2644, %.lr.ph1672 ], [ %4193, %__nv_exp2f.exit1513 ] + %.pn2537 = phi float [ %2645, %.lr.ph1672 ], [ %4194, %__nv_exp2f.exit1513 ] + %.pn2538 = phi float [ %2646, %.lr.ph1672 ], [ %4195, %__nv_exp2f.exit1513 ] + %.pn2539 = phi float [ %2647, %.lr.ph1672 ], [ %4196, %__nv_exp2f.exit1513 ] + %.pn2540 = phi float [ %2648, %.lr.ph1672 ], [ %4197, %__nv_exp2f.exit1513 ] + %.pn2541 = phi float [ %2649, %.lr.ph1672 ], [ %4198, %__nv_exp2f.exit1513 ] + %.pn2542 = phi float [ %2650, %.lr.ph1672 ], [ %4199, %__nv_exp2f.exit1513 ] + %.pn2543 = phi float [ %2651, %.lr.ph1672 ], [ %4200, %__nv_exp2f.exit1513 ] + %.pn2544 = phi float [ %2652, %.lr.ph1672 ], [ %4201, %__nv_exp2f.exit1513 ] + %.pn2545 = phi float [ %2653, %.lr.ph1672 ], [ %4202, %__nv_exp2f.exit1513 ] + %.pn2546 = phi float [ %2654, %.lr.ph1672 ], [ %4203, %__nv_exp2f.exit1513 ] + %.pn2547 = phi float [ %2655, %.lr.ph1672 ], [ %4204, %__nv_exp2f.exit1513 ] + %.pn2548 = phi float [ %2656, %.lr.ph1672 ], [ %4205, %__nv_exp2f.exit1513 ] + %.pn2549 = phi float [ %2657, %.lr.ph1672 ], [ %4206, %__nv_exp2f.exit1513 ] + %.pn2550 = phi float [ %2658, %.lr.ph1672 ], [ %4207, %__nv_exp2f.exit1513 ] + %.pn2551 = phi float [ %2659, %.lr.ph1672 ], [ %4208, %__nv_exp2f.exit1513 ] + %.pn2552 = phi float [ %2660, %.lr.ph1672 ], [ %4209, %__nv_exp2f.exit1513 ] + %.pn2553 = phi float [ %2661, %.lr.ph1672 ], [ %4210, %__nv_exp2f.exit1513 ] + %.pn2554 = phi float [ %2662, %.lr.ph1672 ], [ %4211, %__nv_exp2f.exit1513 ] + %.pn2555 = phi float [ %2663, %.lr.ph1672 ], [ %4212, %__nv_exp2f.exit1513 ] + %.pn2556 = phi float [ %2664, %.lr.ph1672 ], [ %4213, %__nv_exp2f.exit1513 ] + %.pn2557 = phi float [ %2665, %.lr.ph1672 ], [ %4214, %__nv_exp2f.exit1513 ] + %.pn2558 = phi float [ %2666, %.lr.ph1672 ], [ %4215, %__nv_exp2f.exit1513 ] + %.pn2559 = phi float [ %2667, %.lr.ph1672 ], [ %4216, %__nv_exp2f.exit1513 ] + %.pn2560 = phi float [ %2668, %.lr.ph1672 ], [ %4217, %__nv_exp2f.exit1513 ] + %.pn2561 = phi float [ %2669, %.lr.ph1672 ], [ %4218, %__nv_exp2f.exit1513 ] + %.pn2562 = phi float [ %2670, %.lr.ph1672 ], [ %4219, %__nv_exp2f.exit1513 ] + %.pn2563 = phi float [ %2671, %.lr.ph1672 ], [ %4220, %__nv_exp2f.exit1513 ] + %.pn2564 = phi float [ %2672, %.lr.ph1672 ], [ %4221, %__nv_exp2f.exit1513 ] + %.pn2565 = phi float [ %2673, %.lr.ph1672 ], [ %4222, %__nv_exp2f.exit1513 ] + %.pn2566 = phi float [ %2674, %.lr.ph1672 ], [ %4223, %__nv_exp2f.exit1513 ] + %.pn2567 = phi float [ %2675, %.lr.ph1672 ], [ %4224, %__nv_exp2f.exit1513 ] + %.pn2568 = phi float [ %2676, %.lr.ph1672 ], [ %4225, %__nv_exp2f.exit1513 ] + %.pn2569 = phi float [ %2677, %.lr.ph1672 ], [ %4226, %__nv_exp2f.exit1513 ] + %.pn2570 = phi float [ %2678, %.lr.ph1672 ], [ %4227, %__nv_exp2f.exit1513 ] + %.pn2571 = phi float [ %2679, %.lr.ph1672 ], [ %4228, %__nv_exp2f.exit1513 ] + %.pn2572 = phi float [ %2680, %.lr.ph1672 ], [ %4229, %__nv_exp2f.exit1513 ] + %.pn2573 = phi float [ %2681, %.lr.ph1672 ], [ %4230, %__nv_exp2f.exit1513 ] + %.pn2574 = phi float [ %2682, %.lr.ph1672 ], [ %4231, %__nv_exp2f.exit1513 ] + %.pn2575 = phi float [ %2683, %.lr.ph1672 ], [ %4232, %__nv_exp2f.exit1513 ] + %.pn2576 = phi float [ %2684, %.lr.ph1672 ], [ %4233, %__nv_exp2f.exit1513 ] + %.pn2577 = phi float [ %2685, %.lr.ph1672 ], [ %4234, %__nv_exp2f.exit1513 ] + %.pn2578 = phi float [ %2686, %.lr.ph1672 ], [ %4235, %__nv_exp2f.exit1513 ] + %.pn2579 = phi float [ %2687, %.lr.ph1672 ], [ %4236, %__nv_exp2f.exit1513 ] + %.pn2580 = phi float [ %2688, %.lr.ph1672 ], [ %4237, %__nv_exp2f.exit1513 ] + %.pn2581 = phi float [ %2689, %.lr.ph1672 ], [ %4238, %__nv_exp2f.exit1513 ] + %.pn2582 = phi float [ %2690, %.lr.ph1672 ], [ %4239, %__nv_exp2f.exit1513 ] + %.pn2583 = phi float [ %2691, %.lr.ph1672 ], [ %4240, %__nv_exp2f.exit1513 ] + %.pn2584 = phi float [ %2692, %.lr.ph1672 ], [ %4241, %__nv_exp2f.exit1513 ] + %.pn2585 = phi float [ %2693, %.lr.ph1672 ], [ %4242, %__nv_exp2f.exit1513 ] + %.pn2586 = phi float [ %2694, %.lr.ph1672 ], [ %4243, %__nv_exp2f.exit1513 ] + %.pn2587 = phi float [ %2695, %.lr.ph1672 ], [ %4244, %__nv_exp2f.exit1513 ] + %.pn2588 = phi float [ %2696, %.lr.ph1672 ], [ %4245, %__nv_exp2f.exit1513 ] + %.pn2589 = phi float [ %2697, %.lr.ph1672 ], [ %4246, %__nv_exp2f.exit1513 ] + %.pn2590 = phi float [ %2698, %.lr.ph1672 ], [ %4247, %__nv_exp2f.exit1513 ] + %.pn2591 = phi float [ %2699, %.lr.ph1672 ], [ %4248, %__nv_exp2f.exit1513 ] + %.pn2592 = phi float [ %2700, %.lr.ph1672 ], [ %4249, %__nv_exp2f.exit1513 ] + %2705 = phi i32 [ 0, %.lr.ph1672 ], [ %4253, %__nv_exp2f.exit1513 ] + %2706 = phi <16 x i32> [ %2634, %.lr.ph1672 ], [ %4252, %__nv_exp2f.exit1513 ] + %2707 = icmp slt i32 %2705, %2635, !dbg !166 + %2708 = icmp slt i32 %2705, %2636, !dbg !166 + %2709 = add i32 %2703, 1, !dbg !166 + %2710 = icmp sgt i32 %2709, 2, !dbg !166 + %2711 = select i1 %2710, i32 0, i32 %2709, !dbg !166 + %2712 = extractelement <16 x i32> %2706, i64 15, !dbg !167 + %2713 = icmp slt i32 %2712, %18, !dbg !167 + %2714 = extractelement <16 x i32> %2706, i64 14, !dbg !167 + %2715 = icmp slt i32 %2714, %18, !dbg !167 + %2716 = extractelement <16 x i32> %2706, i64 13, !dbg !167 + %2717 = icmp slt i32 %2716, %18, !dbg !167 + %2718 = extractelement <16 x i32> %2706, i64 12, !dbg !167 + %2719 = icmp slt i32 %2718, %18, !dbg !167 + %2720 = extractelement <16 x i32> %2706, i64 11, !dbg !167 + %2721 = icmp slt i32 %2720, %18, !dbg !167 + %2722 = extractelement <16 x i32> %2706, i64 10, !dbg !167 + %2723 = icmp slt i32 %2722, %18, !dbg !167 + %2724 = extractelement <16 x i32> %2706, i64 9, !dbg !167 + %2725 = icmp slt i32 %2724, %18, !dbg !167 + %2726 = extractelement <16 x i32> %2706, i64 8, !dbg !167 + %2727 = icmp slt i32 %2726, %18, !dbg !167 + %2728 = extractelement <16 x i32> %2706, i64 7, !dbg !167 + %2729 = icmp slt i32 %2728, %18, !dbg !167 + %2730 = extractelement <16 x i32> %2706, i64 6, !dbg !167 + %2731 = icmp slt i32 %2730, %18, !dbg !167 + %2732 = extractelement <16 x i32> %2706, i64 5, !dbg !167 + %2733 = icmp slt i32 %2732, %18, !dbg !167 + %2734 = extractelement <16 x i32> %2706, i64 4, !dbg !167 + %2735 = icmp slt i32 %2734, %18, !dbg !167 + %2736 = extractelement <16 x i32> %2706, i64 3, !dbg !167 + %2737 = icmp slt i32 %2736, %18, !dbg !167 + %2738 = extractelement <16 x i32> %2706, i64 2, !dbg !167 + %2739 = icmp slt i32 %2738, %18, !dbg !167 + %2740 = extractelement <16 x i32> %2706, i64 1, !dbg !167 + %2741 = icmp slt i32 %2740, %18, !dbg !167 + %2742 = extractelement <16 x i32> %2706, i64 0, !dbg !167 + %2743 = icmp slt i32 %2742, %18, !dbg !167 + tail call void @llvm.nvvm.cp.async.wait.group(i32 2), !dbg !168 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !168 + %2744 = shl i32 %2711, 13, !dbg !168 + %2745 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %2744, !dbg !168 + %2746 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %45, i32 0, i32 31), !dbg !172 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !172 + %2747 = shl i32 %2746, 11, !dbg !172 + %2748 = and i32 %2747, 8192, !dbg !172 + %2749 = add i32 %2748, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2750 = lshr exact i32 %2749, 4, !dbg !172 + %2751 = and i32 %2750, 16383, !dbg !172 + %2752 = zext nneg i32 %2751 to i64, !dbg !172 + %2753 = or disjoint i64 %2752, 4611686293372403712, !dbg !172 + %2754 = ptrtoint ptr addrspace(3) %2745 to i32, !dbg !172 + %2755 = lshr exact i32 %2754, 4, !dbg !172 + %2756 = and i32 %2755, 16383, !dbg !172 + %2757 = zext nneg i32 %2756 to i64, !dbg !172 + %2758 = or disjoint i64 %2757, 4611686293338849280, !dbg !172 + %2759 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %2753, i64 %2758) #3, !dbg !172 + %2760 = or disjoint i32 %2748, 32, !dbg !172 + %2761 = add i32 %2760, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2762 = lshr exact i32 %2761, 4, !dbg !172 + %2763 = and i32 %2762, 16383, !dbg !172 + %2764 = zext nneg i32 %2763 to i64, !dbg !172 + %2765 = or disjoint i64 %2764, 4611686293372403712, !dbg !172 + %2766 = add i32 %2754, 32, !dbg !172 + %2767 = lshr exact i32 %2766, 4, !dbg !172 + %2768 = and i32 %2767, 16383, !dbg !172 + %2769 = zext nneg i32 %2768 to i64, !dbg !172 + %2770 = or disjoint i64 %2769, 4611686293338849280, !dbg !172 + %2771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 0, !dbg !172 + %2772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 1, !dbg !172 + %2773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 2, !dbg !172 + %2774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 3, !dbg !172 + %2775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 4, !dbg !172 + %2776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 5, !dbg !172 + %2777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 6, !dbg !172 + %2778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 7, !dbg !172 + %2779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 8, !dbg !172 + %2780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 9, !dbg !172 + %2781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 10, !dbg !172 + %2782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 11, !dbg !172 + %2783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 12, !dbg !172 + %2784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 13, !dbg !172 + %2785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 14, !dbg !172 + %2786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 15, !dbg !172 + %2787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 16, !dbg !172 + %2788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 17, !dbg !172 + %2789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 18, !dbg !172 + %2790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 19, !dbg !172 + %2791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 20, !dbg !172 + %2792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 21, !dbg !172 + %2793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 22, !dbg !172 + %2794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 23, !dbg !172 + %2795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 24, !dbg !172 + %2796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 25, !dbg !172 + %2797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 26, !dbg !172 + %2798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 27, !dbg !172 + %2799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 28, !dbg !172 + %2800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 29, !dbg !172 + %2801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 30, !dbg !172 + %2802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2759, 31, !dbg !172 + %2803 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2771, float %2772, float %2773, float %2774, float %2775, float %2776, float %2777, float %2778, float %2779, float %2780, float %2781, float %2782, float %2783, float %2784, float %2785, float %2786, float %2787, float %2788, float %2789, float %2790, float %2791, float %2792, float %2793, float %2794, float %2795, float %2796, float %2797, float %2798, float %2799, float %2800, float %2801, float %2802, i64 %2765, i64 %2770, i1 true) #3, !dbg !172 + %2804 = or disjoint i32 %2748, 64, !dbg !172 + %2805 = add i32 %2804, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2806 = lshr exact i32 %2805, 4, !dbg !172 + %2807 = and i32 %2806, 16383, !dbg !172 + %2808 = zext nneg i32 %2807 to i64, !dbg !172 + %2809 = or disjoint i64 %2808, 4611686293372403712, !dbg !172 + %2810 = add i32 %2754, 64, !dbg !172 + %2811 = lshr exact i32 %2810, 4, !dbg !172 + %2812 = and i32 %2811, 16383, !dbg !172 + %2813 = zext nneg i32 %2812 to i64, !dbg !172 + %2814 = or disjoint i64 %2813, 4611686293338849280, !dbg !172 + %2815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 0, !dbg !172 + %2816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 1, !dbg !172 + %2817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 2, !dbg !172 + %2818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 3, !dbg !172 + %2819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 4, !dbg !172 + %2820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 5, !dbg !172 + %2821 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 6, !dbg !172 + %2822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 7, !dbg !172 + %2823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 8, !dbg !172 + %2824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 9, !dbg !172 + %2825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 10, !dbg !172 + %2826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 11, !dbg !172 + %2827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 12, !dbg !172 + %2828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 13, !dbg !172 + %2829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 14, !dbg !172 + %2830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 15, !dbg !172 + %2831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 16, !dbg !172 + %2832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 17, !dbg !172 + %2833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 18, !dbg !172 + %2834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 19, !dbg !172 + %2835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 20, !dbg !172 + %2836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 21, !dbg !172 + %2837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 22, !dbg !172 + %2838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 23, !dbg !172 + %2839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 24, !dbg !172 + %2840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 25, !dbg !172 + %2841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 26, !dbg !172 + %2842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 27, !dbg !172 + %2843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 28, !dbg !172 + %2844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 29, !dbg !172 + %2845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 30, !dbg !172 + %2846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2803, 31, !dbg !172 + %2847 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2815, float %2816, float %2817, float %2818, float %2819, float %2820, float %2821, float %2822, float %2823, float %2824, float %2825, float %2826, float %2827, float %2828, float %2829, float %2830, float %2831, float %2832, float %2833, float %2834, float %2835, float %2836, float %2837, float %2838, float %2839, float %2840, float %2841, float %2842, float %2843, float %2844, float %2845, float %2846, i64 %2809, i64 %2814, i1 true) #3, !dbg !172 + %2848 = or disjoint i32 %2748, 96, !dbg !172 + %2849 = add i32 %2848, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2850 = lshr exact i32 %2849, 4, !dbg !172 + %2851 = and i32 %2850, 16383, !dbg !172 + %2852 = zext nneg i32 %2851 to i64, !dbg !172 + %2853 = or disjoint i64 %2852, 4611686293372403712, !dbg !172 + %2854 = add i32 %2754, 96, !dbg !172 + %2855 = lshr exact i32 %2854, 4, !dbg !172 + %2856 = and i32 %2855, 16383, !dbg !172 + %2857 = zext nneg i32 %2856 to i64, !dbg !172 + %2858 = or disjoint i64 %2857, 4611686293338849280, !dbg !172 + %2859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 0, !dbg !172 + %2860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 1, !dbg !172 + %2861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 2, !dbg !172 + %2862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 3, !dbg !172 + %2863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 4, !dbg !172 + %2864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 5, !dbg !172 + %2865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 6, !dbg !172 + %2866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 7, !dbg !172 + %2867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 8, !dbg !172 + %2868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 9, !dbg !172 + %2869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 10, !dbg !172 + %2870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 11, !dbg !172 + %2871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 12, !dbg !172 + %2872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 13, !dbg !172 + %2873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 14, !dbg !172 + %2874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 15, !dbg !172 + %2875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 16, !dbg !172 + %2876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 17, !dbg !172 + %2877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 18, !dbg !172 + %2878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 19, !dbg !172 + %2879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 20, !dbg !172 + %2880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 21, !dbg !172 + %2881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 22, !dbg !172 + %2882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 23, !dbg !172 + %2883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 24, !dbg !172 + %2884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 25, !dbg !172 + %2885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 26, !dbg !172 + %2886 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 27, !dbg !172 + %2887 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 28, !dbg !172 + %2888 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 29, !dbg !172 + %2889 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 30, !dbg !172 + %2890 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2847, 31, !dbg !172 + %2891 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2859, float %2860, float %2861, float %2862, float %2863, float %2864, float %2865, float %2866, float %2867, float %2868, float %2869, float %2870, float %2871, float %2872, float %2873, float %2874, float %2875, float %2876, float %2877, float %2878, float %2879, float %2880, float %2881, float %2882, float %2883, float %2884, float %2885, float %2886, float %2887, float %2888, float %2889, float %2890, i64 %2853, i64 %2858, i1 true) #3, !dbg !172 + %2892 = or disjoint i32 %2748, 16384, !dbg !172 + %2893 = add i32 %2892, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2894 = lshr exact i32 %2893, 4, !dbg !172 + %2895 = and i32 %2894, 16383, !dbg !172 + %2896 = zext nneg i32 %2895 to i64, !dbg !172 + %2897 = or disjoint i64 %2896, 4611686293372403712, !dbg !172 + %2898 = add i32 %2754, 8192, !dbg !172 + %2899 = lshr exact i32 %2898, 4, !dbg !172 + %2900 = and i32 %2899, 16383, !dbg !172 + %2901 = zext nneg i32 %2900 to i64, !dbg !172 + %2902 = or disjoint i64 %2901, 4611686293338849280, !dbg !172 + %2903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 0, !dbg !172 + %2904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 1, !dbg !172 + %2905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 2, !dbg !172 + %2906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 3, !dbg !172 + %2907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 4, !dbg !172 + %2908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 5, !dbg !172 + %2909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 6, !dbg !172 + %2910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 7, !dbg !172 + %2911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 8, !dbg !172 + %2912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 9, !dbg !172 + %2913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 10, !dbg !172 + %2914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 11, !dbg !172 + %2915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 12, !dbg !172 + %2916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 13, !dbg !172 + %2917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 14, !dbg !172 + %2918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 15, !dbg !172 + %2919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 16, !dbg !172 + %2920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 17, !dbg !172 + %2921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 18, !dbg !172 + %2922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 19, !dbg !172 + %2923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 20, !dbg !172 + %2924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 21, !dbg !172 + %2925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 22, !dbg !172 + %2926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 23, !dbg !172 + %2927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 24, !dbg !172 + %2928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 25, !dbg !172 + %2929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 26, !dbg !172 + %2930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 27, !dbg !172 + %2931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 28, !dbg !172 + %2932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 29, !dbg !172 + %2933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 30, !dbg !172 + %2934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2891, 31, !dbg !172 + %2935 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2903, float %2904, float %2905, float %2906, float %2907, float %2908, float %2909, float %2910, float %2911, float %2912, float %2913, float %2914, float %2915, float %2916, float %2917, float %2918, float %2919, float %2920, float %2921, float %2922, float %2923, float %2924, float %2925, float %2926, float %2927, float %2928, float %2929, float %2930, float %2931, float %2932, float %2933, float %2934, i64 %2897, i64 %2902, i1 true) #3, !dbg !172 + %2936 = or disjoint i32 %2748, 16416, !dbg !172 + %2937 = add i32 %2936, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2938 = lshr exact i32 %2937, 4, !dbg !172 + %2939 = and i32 %2938, 16383, !dbg !172 + %2940 = zext nneg i32 %2939 to i64, !dbg !172 + %2941 = or disjoint i64 %2940, 4611686293372403712, !dbg !172 + %2942 = add i32 %2754, 8224, !dbg !172 + %2943 = lshr exact i32 %2942, 4, !dbg !172 + %2944 = and i32 %2943, 16383, !dbg !172 + %2945 = zext nneg i32 %2944 to i64, !dbg !172 + %2946 = or disjoint i64 %2945, 4611686293338849280, !dbg !172 + %2947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 0, !dbg !172 + %2948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 1, !dbg !172 + %2949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 2, !dbg !172 + %2950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 3, !dbg !172 + %2951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 4, !dbg !172 + %2952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 5, !dbg !172 + %2953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 6, !dbg !172 + %2954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 7, !dbg !172 + %2955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 8, !dbg !172 + %2956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 9, !dbg !172 + %2957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 10, !dbg !172 + %2958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 11, !dbg !172 + %2959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 12, !dbg !172 + %2960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 13, !dbg !172 + %2961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 14, !dbg !172 + %2962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 15, !dbg !172 + %2963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 16, !dbg !172 + %2964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 17, !dbg !172 + %2965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 18, !dbg !172 + %2966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 19, !dbg !172 + %2967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 20, !dbg !172 + %2968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 21, !dbg !172 + %2969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 22, !dbg !172 + %2970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 23, !dbg !172 + %2971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 24, !dbg !172 + %2972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 25, !dbg !172 + %2973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 26, !dbg !172 + %2974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 27, !dbg !172 + %2975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 28, !dbg !172 + %2976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 29, !dbg !172 + %2977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 30, !dbg !172 + %2978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2935, 31, !dbg !172 + %2979 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2947, float %2948, float %2949, float %2950, float %2951, float %2952, float %2953, float %2954, float %2955, float %2956, float %2957, float %2958, float %2959, float %2960, float %2961, float %2962, float %2963, float %2964, float %2965, float %2966, float %2967, float %2968, float %2969, float %2970, float %2971, float %2972, float %2973, float %2974, float %2975, float %2976, float %2977, float %2978, i64 %2941, i64 %2946, i1 true) #3, !dbg !172 + %2980 = or disjoint i32 %2748, 16448, !dbg !172 + %2981 = add i32 %2980, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %2982 = lshr exact i32 %2981, 4, !dbg !172 + %2983 = and i32 %2982, 16383, !dbg !172 + %2984 = zext nneg i32 %2983 to i64, !dbg !172 + %2985 = or disjoint i64 %2984, 4611686293372403712, !dbg !172 + %2986 = add i32 %2754, 8256, !dbg !172 + %2987 = lshr exact i32 %2986, 4, !dbg !172 + %2988 = and i32 %2987, 16383, !dbg !172 + %2989 = zext nneg i32 %2988 to i64, !dbg !172 + %2990 = or disjoint i64 %2989, 4611686293338849280, !dbg !172 + %2991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 0, !dbg !172 + %2992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 1, !dbg !172 + %2993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 2, !dbg !172 + %2994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 3, !dbg !172 + %2995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 4, !dbg !172 + %2996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 5, !dbg !172 + %2997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 6, !dbg !172 + %2998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 7, !dbg !172 + %2999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 8, !dbg !172 + %3000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 9, !dbg !172 + %3001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 10, !dbg !172 + %3002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 11, !dbg !172 + %3003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 12, !dbg !172 + %3004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 13, !dbg !172 + %3005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 14, !dbg !172 + %3006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 15, !dbg !172 + %3007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 16, !dbg !172 + %3008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 17, !dbg !172 + %3009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 18, !dbg !172 + %3010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 19, !dbg !172 + %3011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 20, !dbg !172 + %3012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 21, !dbg !172 + %3013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 22, !dbg !172 + %3014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 23, !dbg !172 + %3015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 24, !dbg !172 + %3016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 25, !dbg !172 + %3017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 26, !dbg !172 + %3018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 27, !dbg !172 + %3019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 28, !dbg !172 + %3020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 29, !dbg !172 + %3021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 30, !dbg !172 + %3022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %2979, 31, !dbg !172 + %3023 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %2991, float %2992, float %2993, float %2994, float %2995, float %2996, float %2997, float %2998, float %2999, float %3000, float %3001, float %3002, float %3003, float %3004, float %3005, float %3006, float %3007, float %3008, float %3009, float %3010, float %3011, float %3012, float %3013, float %3014, float %3015, float %3016, float %3017, float %3018, float %3019, float %3020, float %3021, float %3022, i64 %2985, i64 %2990, i1 true) #3, !dbg !172 + %3024 = or disjoint i32 %2748, 16480, !dbg !172 + %3025 = add i32 %3024, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304) to i32), !dbg !172 + %3026 = lshr exact i32 %3025, 4, !dbg !172 + %3027 = and i32 %3026, 16383, !dbg !172 + %3028 = zext nneg i32 %3027 to i64, !dbg !172 + %3029 = or disjoint i64 %3028, 4611686293372403712, !dbg !172 + %3030 = add i32 %2754, 8288, !dbg !172 + %3031 = lshr exact i32 %3030, 4, !dbg !172 + %3032 = and i32 %3031, 16383, !dbg !172 + %3033 = zext nneg i32 %3032 to i64, !dbg !172 + %3034 = or disjoint i64 %3033, 4611686293338849280, !dbg !172 + %3035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 0, !dbg !172 + %3036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 1, !dbg !172 + %3037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 2, !dbg !172 + %3038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 3, !dbg !172 + %3039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 4, !dbg !172 + %3040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 5, !dbg !172 + %3041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 6, !dbg !172 + %3042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 7, !dbg !172 + %3043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 8, !dbg !172 + %3044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 9, !dbg !172 + %3045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 10, !dbg !172 + %3046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 11, !dbg !172 + %3047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 12, !dbg !172 + %3048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 13, !dbg !172 + %3049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 14, !dbg !172 + %3050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 15, !dbg !172 + %3051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 16, !dbg !172 + %3052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 17, !dbg !172 + %3053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 18, !dbg !172 + %3054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 19, !dbg !172 + %3055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 20, !dbg !172 + %3056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 21, !dbg !172 + %3057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 22, !dbg !172 + %3058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 23, !dbg !172 + %3059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 24, !dbg !172 + %3060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 25, !dbg !172 + %3061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 26, !dbg !172 + %3062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 27, !dbg !172 + %3063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 28, !dbg !172 + %3064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 29, !dbg !172 + %3065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 30, !dbg !172 + %3066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3023, 31, !dbg !172 + %3067 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3035, float %3036, float %3037, float %3038, float %3039, float %3040, float %3041, float %3042, float %3043, float %3044, float %3045, float %3046, float %3047, float %3048, float %3049, float %3050, float %3051, float %3052, float %3053, float %3054, float %3055, float %3056, float %3057, float %3058, float %3059, float %3060, float %3061, float %3062, float %3063, float %3064, float %3065, float %3066, i64 %3029, i64 %3034, i1 true) #3, !dbg !172 + %3068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 0, !dbg !172 + %3069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 1, !dbg !172 + %3070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 2, !dbg !172 + %3071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 3, !dbg !172 + %3072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 4, !dbg !172 + %3073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 5, !dbg !172 + %3074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 6, !dbg !172 + %3075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 7, !dbg !172 + %3076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 8, !dbg !172 + %3077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 9, !dbg !172 + %3078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 10, !dbg !172 + %3079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 11, !dbg !172 + %3080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 12, !dbg !172 + %3081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 13, !dbg !172 + %3082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 14, !dbg !172 + %3083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 15, !dbg !172 + %3084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 16, !dbg !172 + %3085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 17, !dbg !172 + %3086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 18, !dbg !172 + %3087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 19, !dbg !172 + %3088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 20, !dbg !172 + %3089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 21, !dbg !172 + %3090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 22, !dbg !172 + %3091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 23, !dbg !172 + %3092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 24, !dbg !172 + %3093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 25, !dbg !172 + %3094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 26, !dbg !172 + %3095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 27, !dbg !172 + %3096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 28, !dbg !172 + %3097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 29, !dbg !172 + %3098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 30, !dbg !172 + %3099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3067, 31, !dbg !172 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !172 + %3100 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %3068, float %3069, float %3070, float %3071, float %3072, float %3073, float %3074, float %3075, float %3076, float %3077, float %3078, float %3079, float %3080, float %3081, float %3082, float %3083, float %3084, float %3085, float %3086, float %3087, float %3088, float %3089, float %3090, float %3091, float %3092, float %3093, float %3094, float %3095, float %3096, float %3097, float %3098, float %3099, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 0, i32 0, ptr addrspace(3) %2745, i32 0, i32 0) #3, !dbg !172 + %3101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 0, !dbg !172 + %3102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 1, !dbg !172 + %3103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 2, !dbg !172 + %3104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 3, !dbg !172 + %3105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 4, !dbg !172 + %3106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 5, !dbg !172 + %3107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 6, !dbg !172 + %3108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 7, !dbg !172 + %3109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 8, !dbg !172 + %3110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 9, !dbg !172 + %3111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 10, !dbg !172 + %3112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 11, !dbg !172 + %3113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 12, !dbg !172 + %3114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 13, !dbg !172 + %3115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 14, !dbg !172 + %3116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 15, !dbg !172 + %3117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 16, !dbg !172 + %3118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 17, !dbg !172 + %3119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 18, !dbg !172 + %3120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 19, !dbg !172 + %3121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 20, !dbg !172 + %3122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 21, !dbg !172 + %3123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 22, !dbg !172 + %3124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 23, !dbg !172 + %3125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 24, !dbg !172 + %3126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 25, !dbg !172 + %3127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 26, !dbg !172 + %3128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 27, !dbg !172 + %3129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 28, !dbg !172 + %3130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 29, !dbg !172 + %3131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 30, !dbg !172 + %3132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3100, 31, !dbg !172 + %3133 = fmul float %3101, 0x3FB6A09E60000000, !dbg !173 + %3134 = fmul float %3102, 0x3FB6A09E60000000, !dbg !173 + %3135 = fmul float %3103, 0x3FB6A09E60000000, !dbg !173 + %3136 = fmul float %3104, 0x3FB6A09E60000000, !dbg !173 + %3137 = fmul float %3105, 0x3FB6A09E60000000, !dbg !173 + %3138 = fmul float %3106, 0x3FB6A09E60000000, !dbg !173 + %3139 = fmul float %3107, 0x3FB6A09E60000000, !dbg !173 + %3140 = fmul float %3108, 0x3FB6A09E60000000, !dbg !173 + %3141 = fmul float %3109, 0x3FB6A09E60000000, !dbg !173 + %3142 = fmul float %3110, 0x3FB6A09E60000000, !dbg !173 + %3143 = fmul float %3111, 0x3FB6A09E60000000, !dbg !173 + %3144 = fmul float %3112, 0x3FB6A09E60000000, !dbg !173 + %3145 = fmul float %3113, 0x3FB6A09E60000000, !dbg !173 + %3146 = fmul float %3114, 0x3FB6A09E60000000, !dbg !173 + %3147 = fmul float %3115, 0x3FB6A09E60000000, !dbg !173 + %3148 = fmul float %3116, 0x3FB6A09E60000000, !dbg !173 + %3149 = fmul float %3117, 0x3FB6A09E60000000, !dbg !173 + %3150 = fmul float %3118, 0x3FB6A09E60000000, !dbg !173 + %3151 = fmul float %3119, 0x3FB6A09E60000000, !dbg !173 + %3152 = fmul float %3120, 0x3FB6A09E60000000, !dbg !173 + %3153 = fmul float %3121, 0x3FB6A09E60000000, !dbg !173 + %3154 = fmul float %3122, 0x3FB6A09E60000000, !dbg !173 + %3155 = fmul float %3123, 0x3FB6A09E60000000, !dbg !173 + %3156 = fmul float %3124, 0x3FB6A09E60000000, !dbg !173 + %3157 = fmul float %3125, 0x3FB6A09E60000000, !dbg !173 + %3158 = fmul float %3126, 0x3FB6A09E60000000, !dbg !173 + %3159 = fmul float %3127, 0x3FB6A09E60000000, !dbg !173 + %3160 = fmul float %3128, 0x3FB6A09E60000000, !dbg !173 + %3161 = fmul float %3129, 0x3FB6A09E60000000, !dbg !173 + %3162 = fmul float %3130, 0x3FB6A09E60000000, !dbg !173 + %3163 = fmul float %3131, 0x3FB6A09E60000000, !dbg !173 + %3164 = fmul float %3132, 0x3FB6A09E60000000, !dbg !173 + %3165 = fmul float %3133, 0x3FF7154760000000, !dbg !174 + %3166 = select i1 %2713, float %3165, float 0xFFF0000000000000, !dbg !175 + %3167 = fmul float %3134, 0x3FF7154760000000, !dbg !174 + %3168 = select i1 %2715, float %3167, float 0xFFF0000000000000, !dbg !175 + %3169 = fmul float %3135, 0x3FF7154760000000, !dbg !174 + %3170 = select i1 %2713, float %3169, float 0xFFF0000000000000, !dbg !175 + %3171 = fmul float %3136, 0x3FF7154760000000, !dbg !174 + %3172 = select i1 %2715, float %3171, float 0xFFF0000000000000, !dbg !175 + %3173 = fmul float %3137, 0x3FF7154760000000, !dbg !174 + %3174 = select i1 %2717, float %3173, float 0xFFF0000000000000, !dbg !175 + %3175 = fmul float %3138, 0x3FF7154760000000, !dbg !174 + %3176 = select i1 %2719, float %3175, float 0xFFF0000000000000, !dbg !175 + %3177 = fmul float %3139, 0x3FF7154760000000, !dbg !174 + %3178 = select i1 %2717, float %3177, float 0xFFF0000000000000, !dbg !175 + %3179 = fmul float %3140, 0x3FF7154760000000, !dbg !174 + %3180 = select i1 %2719, float %3179, float 0xFFF0000000000000, !dbg !175 + %3181 = fmul float %3141, 0x3FF7154760000000, !dbg !174 + %3182 = select i1 %2721, float %3181, float 0xFFF0000000000000, !dbg !175 + %3183 = fmul float %3142, 0x3FF7154760000000, !dbg !174 + %3184 = select i1 %2723, float %3183, float 0xFFF0000000000000, !dbg !175 + %3185 = fmul float %3143, 0x3FF7154760000000, !dbg !174 + %3186 = select i1 %2721, float %3185, float 0xFFF0000000000000, !dbg !175 + %3187 = fmul float %3144, 0x3FF7154760000000, !dbg !174 + %3188 = select i1 %2723, float %3187, float 0xFFF0000000000000, !dbg !175 + %3189 = fmul float %3145, 0x3FF7154760000000, !dbg !174 + %3190 = select i1 %2725, float %3189, float 0xFFF0000000000000, !dbg !175 + %3191 = fmul float %3146, 0x3FF7154760000000, !dbg !174 + %3192 = select i1 %2727, float %3191, float 0xFFF0000000000000, !dbg !175 + %3193 = fmul float %3147, 0x3FF7154760000000, !dbg !174 + %3194 = select i1 %2725, float %3193, float 0xFFF0000000000000, !dbg !175 + %3195 = fmul float %3148, 0x3FF7154760000000, !dbg !174 + %3196 = select i1 %2727, float %3195, float 0xFFF0000000000000, !dbg !175 + %3197 = fmul float %3149, 0x3FF7154760000000, !dbg !174 + %3198 = select i1 %2729, float %3197, float 0xFFF0000000000000, !dbg !175 + %3199 = fmul float %3150, 0x3FF7154760000000, !dbg !174 + %3200 = select i1 %2731, float %3199, float 0xFFF0000000000000, !dbg !175 + %3201 = fmul float %3151, 0x3FF7154760000000, !dbg !174 + %3202 = select i1 %2729, float %3201, float 0xFFF0000000000000, !dbg !175 + %3203 = fmul float %3152, 0x3FF7154760000000, !dbg !174 + %3204 = select i1 %2731, float %3203, float 0xFFF0000000000000, !dbg !175 + %3205 = fmul float %3153, 0x3FF7154760000000, !dbg !174 + %3206 = select i1 %2733, float %3205, float 0xFFF0000000000000, !dbg !175 + %3207 = fmul float %3154, 0x3FF7154760000000, !dbg !174 + %3208 = select i1 %2735, float %3207, float 0xFFF0000000000000, !dbg !175 + %3209 = fmul float %3155, 0x3FF7154760000000, !dbg !174 + %3210 = select i1 %2733, float %3209, float 0xFFF0000000000000, !dbg !175 + %3211 = fmul float %3156, 0x3FF7154760000000, !dbg !174 + %3212 = select i1 %2735, float %3211, float 0xFFF0000000000000, !dbg !175 + %3213 = fmul float %3157, 0x3FF7154760000000, !dbg !174 + %3214 = select i1 %2737, float %3213, float 0xFFF0000000000000, !dbg !175 + %3215 = fmul float %3158, 0x3FF7154760000000, !dbg !174 + %3216 = select i1 %2739, float %3215, float 0xFFF0000000000000, !dbg !175 + %3217 = fmul float %3159, 0x3FF7154760000000, !dbg !174 + %3218 = select i1 %2737, float %3217, float 0xFFF0000000000000, !dbg !175 + %3219 = fmul float %3160, 0x3FF7154760000000, !dbg !174 + %3220 = select i1 %2739, float %3219, float 0xFFF0000000000000, !dbg !175 + %3221 = fmul float %3161, 0x3FF7154760000000, !dbg !174 + %3222 = select i1 %2741, float %3221, float 0xFFF0000000000000, !dbg !175 + %3223 = fmul float %3162, 0x3FF7154760000000, !dbg !174 + %3224 = select i1 %2743, float %3223, float 0xFFF0000000000000, !dbg !175 + %3225 = fmul float %3163, 0x3FF7154760000000, !dbg !174 + %3226 = select i1 %2741, float %3225, float 0xFFF0000000000000, !dbg !175 + %3227 = fmul float %3164, 0x3FF7154760000000, !dbg !174 + %3228 = select i1 %2743, float %3227, float 0xFFF0000000000000, !dbg !175 + %3229 = fsub float %3166, %366, !dbg !176 + %3230 = fsub float %3168, %366, !dbg !176 + %3231 = fsub float %3170, %367, !dbg !176 + %3232 = fsub float %3172, %367, !dbg !176 + %3233 = fsub float %3174, %366, !dbg !176 + %3234 = fsub float %3176, %366, !dbg !176 + %3235 = fsub float %3178, %367, !dbg !176 + %3236 = fsub float %3180, %367, !dbg !176 + %3237 = fsub float %3182, %366, !dbg !176 + %3238 = fsub float %3184, %366, !dbg !176 + %3239 = fsub float %3186, %367, !dbg !176 + %3240 = fsub float %3188, %367, !dbg !176 + %3241 = fsub float %3190, %366, !dbg !176 + %3242 = fsub float %3192, %366, !dbg !176 + %3243 = fsub float %3194, %367, !dbg !176 + %3244 = fsub float %3196, %367, !dbg !176 + %3245 = fsub float %3198, %366, !dbg !176 + %3246 = fsub float %3200, %366, !dbg !176 + %3247 = fsub float %3202, %367, !dbg !176 + %3248 = fsub float %3204, %367, !dbg !176 + %3249 = fsub float %3206, %366, !dbg !176 + %3250 = fsub float %3208, %366, !dbg !176 + %3251 = fsub float %3210, %367, !dbg !176 + %3252 = fsub float %3212, %367, !dbg !176 + %3253 = fsub float %3214, %366, !dbg !176 + %3254 = fsub float %3216, %366, !dbg !176 + %3255 = fsub float %3218, %367, !dbg !176 + %3256 = fsub float %3220, %367, !dbg !176 + %3257 = fsub float %3222, %366, !dbg !176 + %3258 = fsub float %3224, %366, !dbg !176 + %3259 = fsub float %3226, %367, !dbg !176 + %3260 = fsub float %3228, %367, !dbg !176 + %3261 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1418 = icmp eq i32 %3261, 0, !dbg !177 + br i1 %.not.i1418, label %3264, label %3262, !dbg !177 + +3262: ; preds = %2701 + %3263 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3229) #3, !dbg !177 + br label %__nv_exp2f.exit1420, !dbg !177 + +3264: ; preds = %2701 + %3265 = tail call float @llvm.nvvm.ex2.approx.f(float %3229) #3, !dbg !177 + br label %__nv_exp2f.exit1420, !dbg !177 + +__nv_exp2f.exit1420: ; preds = %3262, %3264 + %.0.i1419 = phi float [ %3263, %3262 ], [ %3265, %3264 ], !dbg !177 + %3266 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1421 = icmp eq i32 %3266, 0, !dbg !177 + br i1 %.not.i1421, label %3269, label %3267, !dbg !177 + +3267: ; preds = %__nv_exp2f.exit1420 + %3268 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3230) #3, !dbg !177 + br label %__nv_exp2f.exit1423, !dbg !177 + +3269: ; preds = %__nv_exp2f.exit1420 + %3270 = tail call float @llvm.nvvm.ex2.approx.f(float %3230) #3, !dbg !177 + br label %__nv_exp2f.exit1423, !dbg !177 + +__nv_exp2f.exit1423: ; preds = %3267, %3269 + %.0.i1422 = phi float [ %3268, %3267 ], [ %3270, %3269 ], !dbg !177 + %3271 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1424 = icmp eq i32 %3271, 0, !dbg !177 + br i1 %.not.i1424, label %3274, label %3272, !dbg !177 + +3272: ; preds = %__nv_exp2f.exit1423 + %3273 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3231) #3, !dbg !177 + br label %__nv_exp2f.exit1426, !dbg !177 + +3274: ; preds = %__nv_exp2f.exit1423 + %3275 = tail call float @llvm.nvvm.ex2.approx.f(float %3231) #3, !dbg !177 + br label %__nv_exp2f.exit1426, !dbg !177 + +__nv_exp2f.exit1426: ; preds = %3272, %3274 + %.0.i1425 = phi float [ %3273, %3272 ], [ %3275, %3274 ], !dbg !177 + %3276 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1427 = icmp eq i32 %3276, 0, !dbg !177 + br i1 %.not.i1427, label %3279, label %3277, !dbg !177 + +3277: ; preds = %__nv_exp2f.exit1426 + %3278 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3232) #3, !dbg !177 + br label %__nv_exp2f.exit1429, !dbg !177 + +3279: ; preds = %__nv_exp2f.exit1426 + %3280 = tail call float @llvm.nvvm.ex2.approx.f(float %3232) #3, !dbg !177 + br label %__nv_exp2f.exit1429, !dbg !177 + +__nv_exp2f.exit1429: ; preds = %3277, %3279 + %.0.i1428 = phi float [ %3278, %3277 ], [ %3280, %3279 ], !dbg !177 + %3281 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1430 = icmp eq i32 %3281, 0, !dbg !177 + br i1 %.not.i1430, label %3284, label %3282, !dbg !177 + +3282: ; preds = %__nv_exp2f.exit1429 + %3283 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3233) #3, !dbg !177 + br label %__nv_exp2f.exit1432, !dbg !177 + +3284: ; preds = %__nv_exp2f.exit1429 + %3285 = tail call float @llvm.nvvm.ex2.approx.f(float %3233) #3, !dbg !177 + br label %__nv_exp2f.exit1432, !dbg !177 + +__nv_exp2f.exit1432: ; preds = %3282, %3284 + %.0.i1431 = phi float [ %3283, %3282 ], [ %3285, %3284 ], !dbg !177 + %3286 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1433 = icmp eq i32 %3286, 0, !dbg !177 + br i1 %.not.i1433, label %3289, label %3287, !dbg !177 + +3287: ; preds = %__nv_exp2f.exit1432 + %3288 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3234) #3, !dbg !177 + br label %__nv_exp2f.exit1435, !dbg !177 + +3289: ; preds = %__nv_exp2f.exit1432 + %3290 = tail call float @llvm.nvvm.ex2.approx.f(float %3234) #3, !dbg !177 + br label %__nv_exp2f.exit1435, !dbg !177 + +__nv_exp2f.exit1435: ; preds = %3287, %3289 + %.0.i1434 = phi float [ %3288, %3287 ], [ %3290, %3289 ], !dbg !177 + %3291 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1436 = icmp eq i32 %3291, 0, !dbg !177 + br i1 %.not.i1436, label %3294, label %3292, !dbg !177 + +3292: ; preds = %__nv_exp2f.exit1435 + %3293 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3235) #3, !dbg !177 + br label %__nv_exp2f.exit1438, !dbg !177 + +3294: ; preds = %__nv_exp2f.exit1435 + %3295 = tail call float @llvm.nvvm.ex2.approx.f(float %3235) #3, !dbg !177 + br label %__nv_exp2f.exit1438, !dbg !177 + +__nv_exp2f.exit1438: ; preds = %3292, %3294 + %.0.i1437 = phi float [ %3293, %3292 ], [ %3295, %3294 ], !dbg !177 + %3296 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1439 = icmp eq i32 %3296, 0, !dbg !177 + br i1 %.not.i1439, label %3299, label %3297, !dbg !177 + +3297: ; preds = %__nv_exp2f.exit1438 + %3298 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3236) #3, !dbg !177 + br label %__nv_exp2f.exit1441, !dbg !177 + +3299: ; preds = %__nv_exp2f.exit1438 + %3300 = tail call float @llvm.nvvm.ex2.approx.f(float %3236) #3, !dbg !177 + br label %__nv_exp2f.exit1441, !dbg !177 + +__nv_exp2f.exit1441: ; preds = %3297, %3299 + %.0.i1440 = phi float [ %3298, %3297 ], [ %3300, %3299 ], !dbg !177 + %3301 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1442 = icmp eq i32 %3301, 0, !dbg !177 + br i1 %.not.i1442, label %3304, label %3302, !dbg !177 + +3302: ; preds = %__nv_exp2f.exit1441 + %3303 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3237) #3, !dbg !177 + br label %__nv_exp2f.exit1444, !dbg !177 + +3304: ; preds = %__nv_exp2f.exit1441 + %3305 = tail call float @llvm.nvvm.ex2.approx.f(float %3237) #3, !dbg !177 + br label %__nv_exp2f.exit1444, !dbg !177 + +__nv_exp2f.exit1444: ; preds = %3302, %3304 + %.0.i1443 = phi float [ %3303, %3302 ], [ %3305, %3304 ], !dbg !177 + %3306 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1445 = icmp eq i32 %3306, 0, !dbg !177 + br i1 %.not.i1445, label %3309, label %3307, !dbg !177 + +3307: ; preds = %__nv_exp2f.exit1444 + %3308 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3238) #3, !dbg !177 + br label %__nv_exp2f.exit1447, !dbg !177 + +3309: ; preds = %__nv_exp2f.exit1444 + %3310 = tail call float @llvm.nvvm.ex2.approx.f(float %3238) #3, !dbg !177 + br label %__nv_exp2f.exit1447, !dbg !177 + +__nv_exp2f.exit1447: ; preds = %3307, %3309 + %.0.i1446 = phi float [ %3308, %3307 ], [ %3310, %3309 ], !dbg !177 + %3311 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1448 = icmp eq i32 %3311, 0, !dbg !177 + br i1 %.not.i1448, label %3314, label %3312, !dbg !177 + +3312: ; preds = %__nv_exp2f.exit1447 + %3313 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3239) #3, !dbg !177 + br label %__nv_exp2f.exit1450, !dbg !177 + +3314: ; preds = %__nv_exp2f.exit1447 + %3315 = tail call float @llvm.nvvm.ex2.approx.f(float %3239) #3, !dbg !177 + br label %__nv_exp2f.exit1450, !dbg !177 + +__nv_exp2f.exit1450: ; preds = %3312, %3314 + %.0.i1449 = phi float [ %3313, %3312 ], [ %3315, %3314 ], !dbg !177 + %3316 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1451 = icmp eq i32 %3316, 0, !dbg !177 + br i1 %.not.i1451, label %3319, label %3317, !dbg !177 + +3317: ; preds = %__nv_exp2f.exit1450 + %3318 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3240) #3, !dbg !177 + br label %__nv_exp2f.exit1453, !dbg !177 + +3319: ; preds = %__nv_exp2f.exit1450 + %3320 = tail call float @llvm.nvvm.ex2.approx.f(float %3240) #3, !dbg !177 + br label %__nv_exp2f.exit1453, !dbg !177 + +__nv_exp2f.exit1453: ; preds = %3317, %3319 + %.0.i1452 = phi float [ %3318, %3317 ], [ %3320, %3319 ], !dbg !177 + %3321 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1454 = icmp eq i32 %3321, 0, !dbg !177 + br i1 %.not.i1454, label %3324, label %3322, !dbg !177 + +3322: ; preds = %__nv_exp2f.exit1453 + %3323 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3241) #3, !dbg !177 + br label %__nv_exp2f.exit1456, !dbg !177 + +3324: ; preds = %__nv_exp2f.exit1453 + %3325 = tail call float @llvm.nvvm.ex2.approx.f(float %3241) #3, !dbg !177 + br label %__nv_exp2f.exit1456, !dbg !177 + +__nv_exp2f.exit1456: ; preds = %3322, %3324 + %.0.i1455 = phi float [ %3323, %3322 ], [ %3325, %3324 ], !dbg !177 + %3326 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1457 = icmp eq i32 %3326, 0, !dbg !177 + br i1 %.not.i1457, label %3329, label %3327, !dbg !177 + +3327: ; preds = %__nv_exp2f.exit1456 + %3328 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3242) #3, !dbg !177 + br label %__nv_exp2f.exit1459, !dbg !177 + +3329: ; preds = %__nv_exp2f.exit1456 + %3330 = tail call float @llvm.nvvm.ex2.approx.f(float %3242) #3, !dbg !177 + br label %__nv_exp2f.exit1459, !dbg !177 + +__nv_exp2f.exit1459: ; preds = %3327, %3329 + %.0.i1458 = phi float [ %3328, %3327 ], [ %3330, %3329 ], !dbg !177 + %3331 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1460 = icmp eq i32 %3331, 0, !dbg !177 + br i1 %.not.i1460, label %3334, label %3332, !dbg !177 + +3332: ; preds = %__nv_exp2f.exit1459 + %3333 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3243) #3, !dbg !177 + br label %__nv_exp2f.exit1462, !dbg !177 + +3334: ; preds = %__nv_exp2f.exit1459 + %3335 = tail call float @llvm.nvvm.ex2.approx.f(float %3243) #3, !dbg !177 + br label %__nv_exp2f.exit1462, !dbg !177 + +__nv_exp2f.exit1462: ; preds = %3332, %3334 + %.0.i1461 = phi float [ %3333, %3332 ], [ %3335, %3334 ], !dbg !177 + %3336 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1463 = icmp eq i32 %3336, 0, !dbg !177 + br i1 %.not.i1463, label %3339, label %3337, !dbg !177 + +3337: ; preds = %__nv_exp2f.exit1462 + %3338 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3244) #3, !dbg !177 + br label %__nv_exp2f.exit1465, !dbg !177 + +3339: ; preds = %__nv_exp2f.exit1462 + %3340 = tail call float @llvm.nvvm.ex2.approx.f(float %3244) #3, !dbg !177 + br label %__nv_exp2f.exit1465, !dbg !177 + +__nv_exp2f.exit1465: ; preds = %3337, %3339 + %.0.i1464 = phi float [ %3338, %3337 ], [ %3340, %3339 ], !dbg !177 + %3341 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1466 = icmp eq i32 %3341, 0, !dbg !177 + br i1 %.not.i1466, label %3344, label %3342, !dbg !177 + +3342: ; preds = %__nv_exp2f.exit1465 + %3343 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3245) #3, !dbg !177 + br label %__nv_exp2f.exit1468, !dbg !177 + +3344: ; preds = %__nv_exp2f.exit1465 + %3345 = tail call float @llvm.nvvm.ex2.approx.f(float %3245) #3, !dbg !177 + br label %__nv_exp2f.exit1468, !dbg !177 + +__nv_exp2f.exit1468: ; preds = %3342, %3344 + %.0.i1467 = phi float [ %3343, %3342 ], [ %3345, %3344 ], !dbg !177 + %3346 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1469 = icmp eq i32 %3346, 0, !dbg !177 + br i1 %.not.i1469, label %3349, label %3347, !dbg !177 + +3347: ; preds = %__nv_exp2f.exit1468 + %3348 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3246) #3, !dbg !177 + br label %__nv_exp2f.exit1471, !dbg !177 + +3349: ; preds = %__nv_exp2f.exit1468 + %3350 = tail call float @llvm.nvvm.ex2.approx.f(float %3246) #3, !dbg !177 + br label %__nv_exp2f.exit1471, !dbg !177 + +__nv_exp2f.exit1471: ; preds = %3347, %3349 + %.0.i1470 = phi float [ %3348, %3347 ], [ %3350, %3349 ], !dbg !177 + %3351 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1472 = icmp eq i32 %3351, 0, !dbg !177 + br i1 %.not.i1472, label %3354, label %3352, !dbg !177 + +3352: ; preds = %__nv_exp2f.exit1471 + %3353 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3247) #3, !dbg !177 + br label %__nv_exp2f.exit1474, !dbg !177 + +3354: ; preds = %__nv_exp2f.exit1471 + %3355 = tail call float @llvm.nvvm.ex2.approx.f(float %3247) #3, !dbg !177 + br label %__nv_exp2f.exit1474, !dbg !177 + +__nv_exp2f.exit1474: ; preds = %3352, %3354 + %.0.i1473 = phi float [ %3353, %3352 ], [ %3355, %3354 ], !dbg !177 + %3356 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1475 = icmp eq i32 %3356, 0, !dbg !177 + br i1 %.not.i1475, label %3359, label %3357, !dbg !177 + +3357: ; preds = %__nv_exp2f.exit1474 + %3358 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3248) #3, !dbg !177 + br label %__nv_exp2f.exit1477, !dbg !177 + +3359: ; preds = %__nv_exp2f.exit1474 + %3360 = tail call float @llvm.nvvm.ex2.approx.f(float %3248) #3, !dbg !177 + br label %__nv_exp2f.exit1477, !dbg !177 + +__nv_exp2f.exit1477: ; preds = %3357, %3359 + %.0.i1476 = phi float [ %3358, %3357 ], [ %3360, %3359 ], !dbg !177 + %3361 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1478 = icmp eq i32 %3361, 0, !dbg !177 + br i1 %.not.i1478, label %3364, label %3362, !dbg !177 + +3362: ; preds = %__nv_exp2f.exit1477 + %3363 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3249) #3, !dbg !177 + br label %__nv_exp2f.exit1480, !dbg !177 + +3364: ; preds = %__nv_exp2f.exit1477 + %3365 = tail call float @llvm.nvvm.ex2.approx.f(float %3249) #3, !dbg !177 + br label %__nv_exp2f.exit1480, !dbg !177 + +__nv_exp2f.exit1480: ; preds = %3362, %3364 + %.0.i1479 = phi float [ %3363, %3362 ], [ %3365, %3364 ], !dbg !177 + %3366 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1481 = icmp eq i32 %3366, 0, !dbg !177 + br i1 %.not.i1481, label %3369, label %3367, !dbg !177 + +3367: ; preds = %__nv_exp2f.exit1480 + %3368 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3250) #3, !dbg !177 + br label %__nv_exp2f.exit1483, !dbg !177 + +3369: ; preds = %__nv_exp2f.exit1480 + %3370 = tail call float @llvm.nvvm.ex2.approx.f(float %3250) #3, !dbg !177 + br label %__nv_exp2f.exit1483, !dbg !177 + +__nv_exp2f.exit1483: ; preds = %3367, %3369 + %.0.i1482 = phi float [ %3368, %3367 ], [ %3370, %3369 ], !dbg !177 + %3371 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1484 = icmp eq i32 %3371, 0, !dbg !177 + br i1 %.not.i1484, label %3374, label %3372, !dbg !177 + +3372: ; preds = %__nv_exp2f.exit1483 + %3373 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3251) #3, !dbg !177 + br label %__nv_exp2f.exit1486, !dbg !177 + +3374: ; preds = %__nv_exp2f.exit1483 + %3375 = tail call float @llvm.nvvm.ex2.approx.f(float %3251) #3, !dbg !177 + br label %__nv_exp2f.exit1486, !dbg !177 + +__nv_exp2f.exit1486: ; preds = %3372, %3374 + %.0.i1485 = phi float [ %3373, %3372 ], [ %3375, %3374 ], !dbg !177 + %3376 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1487 = icmp eq i32 %3376, 0, !dbg !177 + br i1 %.not.i1487, label %3379, label %3377, !dbg !177 + +3377: ; preds = %__nv_exp2f.exit1486 + %3378 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3252) #3, !dbg !177 + br label %__nv_exp2f.exit1489, !dbg !177 + +3379: ; preds = %__nv_exp2f.exit1486 + %3380 = tail call float @llvm.nvvm.ex2.approx.f(float %3252) #3, !dbg !177 + br label %__nv_exp2f.exit1489, !dbg !177 + +__nv_exp2f.exit1489: ; preds = %3377, %3379 + %.0.i1488 = phi float [ %3378, %3377 ], [ %3380, %3379 ], !dbg !177 + %3381 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1490 = icmp eq i32 %3381, 0, !dbg !177 + br i1 %.not.i1490, label %3384, label %3382, !dbg !177 + +3382: ; preds = %__nv_exp2f.exit1489 + %3383 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3253) #3, !dbg !177 + br label %__nv_exp2f.exit1492, !dbg !177 + +3384: ; preds = %__nv_exp2f.exit1489 + %3385 = tail call float @llvm.nvvm.ex2.approx.f(float %3253) #3, !dbg !177 + br label %__nv_exp2f.exit1492, !dbg !177 + +__nv_exp2f.exit1492: ; preds = %3382, %3384 + %.0.i1491 = phi float [ %3383, %3382 ], [ %3385, %3384 ], !dbg !177 + %3386 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1493 = icmp eq i32 %3386, 0, !dbg !177 + br i1 %.not.i1493, label %3389, label %3387, !dbg !177 + +3387: ; preds = %__nv_exp2f.exit1492 + %3388 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3254) #3, !dbg !177 + br label %__nv_exp2f.exit1495, !dbg !177 + +3389: ; preds = %__nv_exp2f.exit1492 + %3390 = tail call float @llvm.nvvm.ex2.approx.f(float %3254) #3, !dbg !177 + br label %__nv_exp2f.exit1495, !dbg !177 + +__nv_exp2f.exit1495: ; preds = %3387, %3389 + %.0.i1494 = phi float [ %3388, %3387 ], [ %3390, %3389 ], !dbg !177 + %3391 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1496 = icmp eq i32 %3391, 0, !dbg !177 + br i1 %.not.i1496, label %3394, label %3392, !dbg !177 + +3392: ; preds = %__nv_exp2f.exit1495 + %3393 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3255) #3, !dbg !177 + br label %__nv_exp2f.exit1498, !dbg !177 + +3394: ; preds = %__nv_exp2f.exit1495 + %3395 = tail call float @llvm.nvvm.ex2.approx.f(float %3255) #3, !dbg !177 + br label %__nv_exp2f.exit1498, !dbg !177 + +__nv_exp2f.exit1498: ; preds = %3392, %3394 + %.0.i1497 = phi float [ %3393, %3392 ], [ %3395, %3394 ], !dbg !177 + %3396 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1499 = icmp eq i32 %3396, 0, !dbg !177 + br i1 %.not.i1499, label %3399, label %3397, !dbg !177 + +3397: ; preds = %__nv_exp2f.exit1498 + %3398 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3256) #3, !dbg !177 + br label %__nv_exp2f.exit1501, !dbg !177 + +3399: ; preds = %__nv_exp2f.exit1498 + %3400 = tail call float @llvm.nvvm.ex2.approx.f(float %3256) #3, !dbg !177 + br label %__nv_exp2f.exit1501, !dbg !177 + +__nv_exp2f.exit1501: ; preds = %3397, %3399 + %.0.i1500 = phi float [ %3398, %3397 ], [ %3400, %3399 ], !dbg !177 + %3401 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1502 = icmp eq i32 %3401, 0, !dbg !177 + br i1 %.not.i1502, label %3404, label %3402, !dbg !177 + +3402: ; preds = %__nv_exp2f.exit1501 + %3403 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3257) #3, !dbg !177 + br label %__nv_exp2f.exit1504, !dbg !177 + +3404: ; preds = %__nv_exp2f.exit1501 + %3405 = tail call float @llvm.nvvm.ex2.approx.f(float %3257) #3, !dbg !177 + br label %__nv_exp2f.exit1504, !dbg !177 + +__nv_exp2f.exit1504: ; preds = %3402, %3404 + %.0.i1503 = phi float [ %3403, %3402 ], [ %3405, %3404 ], !dbg !177 + %3406 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1505 = icmp eq i32 %3406, 0, !dbg !177 + br i1 %.not.i1505, label %3409, label %3407, !dbg !177 + +3407: ; preds = %__nv_exp2f.exit1504 + %3408 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3258) #3, !dbg !177 + br label %__nv_exp2f.exit1507, !dbg !177 + +3409: ; preds = %__nv_exp2f.exit1504 + %3410 = tail call float @llvm.nvvm.ex2.approx.f(float %3258) #3, !dbg !177 + br label %__nv_exp2f.exit1507, !dbg !177 + +__nv_exp2f.exit1507: ; preds = %3407, %3409 + %.0.i1506 = phi float [ %3408, %3407 ], [ %3410, %3409 ], !dbg !177 + %3411 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1508 = icmp eq i32 %3411, 0, !dbg !177 + br i1 %.not.i1508, label %3414, label %3412, !dbg !177 + +3412: ; preds = %__nv_exp2f.exit1507 + %3413 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3259) #3, !dbg !177 + br label %__nv_exp2f.exit1510, !dbg !177 + +3414: ; preds = %__nv_exp2f.exit1507 + %3415 = tail call float @llvm.nvvm.ex2.approx.f(float %3259) #3, !dbg !177 + br label %__nv_exp2f.exit1510, !dbg !177 + +__nv_exp2f.exit1510: ; preds = %3412, %3414 + %.0.i1509 = phi float [ %3413, %3412 ], [ %3415, %3414 ], !dbg !177 + %3416 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !177 + %.not.i1511 = icmp eq i32 %3416, 0, !dbg !177 + br i1 %.not.i1511, label %3419, label %3417, !dbg !177 + +3417: ; preds = %__nv_exp2f.exit1510 + %3418 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %3260) #3, !dbg !177 + br label %__nv_exp2f.exit1513, !dbg !177 + +3419: ; preds = %__nv_exp2f.exit1510 + %3420 = tail call float @llvm.nvvm.ex2.approx.f(float %3260) #3, !dbg !177 + br label %__nv_exp2f.exit1513, !dbg !177 + +__nv_exp2f.exit1513: ; preds = %3417, %3419 + %.0.i1512 = phi float [ %3418, %3417 ], [ %3420, %3419 ], !dbg !177 + %3421 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %2744, !dbg !168 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !178 + %3422 = add i32 %2748, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3423 = lshr exact i32 %3422, 4, !dbg !178 + %3424 = and i32 %3423, 16383, !dbg !178 + %3425 = zext nneg i32 %3424 to i64, !dbg !178 + %3426 = or disjoint i64 %3425, 4611686293372403712, !dbg !178 + %3427 = ptrtoint ptr addrspace(3) %3421 to i32, !dbg !178 + %3428 = lshr exact i32 %3427, 4, !dbg !178 + %3429 = and i32 %3428, 16383, !dbg !178 + %3430 = zext nneg i32 %3429 to i64, !dbg !178 + %3431 = or disjoint i64 %3430, 4611686293338849280, !dbg !178 + %3432 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %3426, i64 %3431) #3, !dbg !178 + %3433 = add i32 %2760, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3434 = lshr exact i32 %3433, 4, !dbg !178 + %3435 = and i32 %3434, 16383, !dbg !178 + %3436 = zext nneg i32 %3435 to i64, !dbg !178 + %3437 = or disjoint i64 %3436, 4611686293372403712, !dbg !178 + %3438 = add i32 %3427, 32, !dbg !178 + %3439 = lshr exact i32 %3438, 4, !dbg !178 + %3440 = and i32 %3439, 16383, !dbg !178 + %3441 = zext nneg i32 %3440 to i64, !dbg !178 + %3442 = or disjoint i64 %3441, 4611686293338849280, !dbg !178 + %3443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 0, !dbg !178 + %3444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 1, !dbg !178 + %3445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 2, !dbg !178 + %3446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 3, !dbg !178 + %3447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 4, !dbg !178 + %3448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 5, !dbg !178 + %3449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 6, !dbg !178 + %3450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 7, !dbg !178 + %3451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 8, !dbg !178 + %3452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 9, !dbg !178 + %3453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 10, !dbg !178 + %3454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 11, !dbg !178 + %3455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 12, !dbg !178 + %3456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 13, !dbg !178 + %3457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 14, !dbg !178 + %3458 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 15, !dbg !178 + %3459 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 16, !dbg !178 + %3460 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 17, !dbg !178 + %3461 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 18, !dbg !178 + %3462 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 19, !dbg !178 + %3463 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 20, !dbg !178 + %3464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 21, !dbg !178 + %3465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 22, !dbg !178 + %3466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 23, !dbg !178 + %3467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 24, !dbg !178 + %3468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 25, !dbg !178 + %3469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 26, !dbg !178 + %3470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 27, !dbg !178 + %3471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 28, !dbg !178 + %3472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 29, !dbg !178 + %3473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 30, !dbg !178 + %3474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3432, 31, !dbg !178 + %3475 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3443, float %3444, float %3445, float %3446, float %3447, float %3448, float %3449, float %3450, float %3451, float %3452, float %3453, float %3454, float %3455, float %3456, float %3457, float %3458, float %3459, float %3460, float %3461, float %3462, float %3463, float %3464, float %3465, float %3466, float %3467, float %3468, float %3469, float %3470, float %3471, float %3472, float %3473, float %3474, i64 %3437, i64 %3442, i1 true) #3, !dbg !178 + %3476 = add i32 %2804, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3477 = lshr exact i32 %3476, 4, !dbg !178 + %3478 = and i32 %3477, 16383, !dbg !178 + %3479 = zext nneg i32 %3478 to i64, !dbg !178 + %3480 = or disjoint i64 %3479, 4611686293372403712, !dbg !178 + %3481 = add i32 %3427, 64, !dbg !178 + %3482 = lshr exact i32 %3481, 4, !dbg !178 + %3483 = and i32 %3482, 16383, !dbg !178 + %3484 = zext nneg i32 %3483 to i64, !dbg !178 + %3485 = or disjoint i64 %3484, 4611686293338849280, !dbg !178 + %3486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 0, !dbg !178 + %3487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 1, !dbg !178 + %3488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 2, !dbg !178 + %3489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 3, !dbg !178 + %3490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 4, !dbg !178 + %3491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 5, !dbg !178 + %3492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 6, !dbg !178 + %3493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 7, !dbg !178 + %3494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 8, !dbg !178 + %3495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 9, !dbg !178 + %3496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 10, !dbg !178 + %3497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 11, !dbg !178 + %3498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 12, !dbg !178 + %3499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 13, !dbg !178 + %3500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 14, !dbg !178 + %3501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 15, !dbg !178 + %3502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 16, !dbg !178 + %3503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 17, !dbg !178 + %3504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 18, !dbg !178 + %3505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 19, !dbg !178 + %3506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 20, !dbg !178 + %3507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 21, !dbg !178 + %3508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 22, !dbg !178 + %3509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 23, !dbg !178 + %3510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 24, !dbg !178 + %3511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 25, !dbg !178 + %3512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 26, !dbg !178 + %3513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 27, !dbg !178 + %3514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 28, !dbg !178 + %3515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 29, !dbg !178 + %3516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 30, !dbg !178 + %3517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3475, 31, !dbg !178 + %3518 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3486, float %3487, float %3488, float %3489, float %3490, float %3491, float %3492, float %3493, float %3494, float %3495, float %3496, float %3497, float %3498, float %3499, float %3500, float %3501, float %3502, float %3503, float %3504, float %3505, float %3506, float %3507, float %3508, float %3509, float %3510, float %3511, float %3512, float %3513, float %3514, float %3515, float %3516, float %3517, i64 %3480, i64 %3485, i1 true) #3, !dbg !178 + %3519 = add i32 %2848, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3520 = lshr exact i32 %3519, 4, !dbg !178 + %3521 = and i32 %3520, 16383, !dbg !178 + %3522 = zext nneg i32 %3521 to i64, !dbg !178 + %3523 = or disjoint i64 %3522, 4611686293372403712, !dbg !178 + %3524 = add i32 %3427, 96, !dbg !178 + %3525 = lshr exact i32 %3524, 4, !dbg !178 + %3526 = and i32 %3525, 16383, !dbg !178 + %3527 = zext nneg i32 %3526 to i64, !dbg !178 + %3528 = or disjoint i64 %3527, 4611686293338849280, !dbg !178 + %3529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 0, !dbg !178 + %3530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 1, !dbg !178 + %3531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 2, !dbg !178 + %3532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 3, !dbg !178 + %3533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 4, !dbg !178 + %3534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 5, !dbg !178 + %3535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 6, !dbg !178 + %3536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 7, !dbg !178 + %3537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 8, !dbg !178 + %3538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 9, !dbg !178 + %3539 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 10, !dbg !178 + %3540 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 11, !dbg !178 + %3541 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 12, !dbg !178 + %3542 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 13, !dbg !178 + %3543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 14, !dbg !178 + %3544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 15, !dbg !178 + %3545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 16, !dbg !178 + %3546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 17, !dbg !178 + %3547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 18, !dbg !178 + %3548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 19, !dbg !178 + %3549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 20, !dbg !178 + %3550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 21, !dbg !178 + %3551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 22, !dbg !178 + %3552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 23, !dbg !178 + %3553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 24, !dbg !178 + %3554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 25, !dbg !178 + %3555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 26, !dbg !178 + %3556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 27, !dbg !178 + %3557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 28, !dbg !178 + %3558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 29, !dbg !178 + %3559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 30, !dbg !178 + %3560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3518, 31, !dbg !178 + %3561 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3529, float %3530, float %3531, float %3532, float %3533, float %3534, float %3535, float %3536, float %3537, float %3538, float %3539, float %3540, float %3541, float %3542, float %3543, float %3544, float %3545, float %3546, float %3547, float %3548, float %3549, float %3550, float %3551, float %3552, float %3553, float %3554, float %3555, float %3556, float %3557, float %3558, float %3559, float %3560, i64 %3523, i64 %3528, i1 true) #3, !dbg !178 + %3562 = add i32 %2892, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3563 = lshr exact i32 %3562, 4, !dbg !178 + %3564 = and i32 %3563, 16383, !dbg !178 + %3565 = zext nneg i32 %3564 to i64, !dbg !178 + %3566 = or disjoint i64 %3565, 4611686293372403712, !dbg !178 + %3567 = add i32 %3427, 8192, !dbg !178 + %3568 = lshr exact i32 %3567, 4, !dbg !178 + %3569 = and i32 %3568, 16383, !dbg !178 + %3570 = zext nneg i32 %3569 to i64, !dbg !178 + %3571 = or disjoint i64 %3570, 4611686293338849280, !dbg !178 + %3572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 0, !dbg !178 + %3573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 1, !dbg !178 + %3574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 2, !dbg !178 + %3575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 3, !dbg !178 + %3576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 4, !dbg !178 + %3577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 5, !dbg !178 + %3578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 6, !dbg !178 + %3579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 7, !dbg !178 + %3580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 8, !dbg !178 + %3581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 9, !dbg !178 + %3582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 10, !dbg !178 + %3583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 11, !dbg !178 + %3584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 12, !dbg !178 + %3585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 13, !dbg !178 + %3586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 14, !dbg !178 + %3587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 15, !dbg !178 + %3588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 16, !dbg !178 + %3589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 17, !dbg !178 + %3590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 18, !dbg !178 + %3591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 19, !dbg !178 + %3592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 20, !dbg !178 + %3593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 21, !dbg !178 + %3594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 22, !dbg !178 + %3595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 23, !dbg !178 + %3596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 24, !dbg !178 + %3597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 25, !dbg !178 + %3598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 26, !dbg !178 + %3599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 27, !dbg !178 + %3600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 28, !dbg !178 + %3601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 29, !dbg !178 + %3602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 30, !dbg !178 + %3603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3561, 31, !dbg !178 + %3604 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3572, float %3573, float %3574, float %3575, float %3576, float %3577, float %3578, float %3579, float %3580, float %3581, float %3582, float %3583, float %3584, float %3585, float %3586, float %3587, float %3588, float %3589, float %3590, float %3591, float %3592, float %3593, float %3594, float %3595, float %3596, float %3597, float %3598, float %3599, float %3600, float %3601, float %3602, float %3603, i64 %3566, i64 %3571, i1 true) #3, !dbg !178 + %3605 = add i32 %2936, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3606 = lshr exact i32 %3605, 4, !dbg !178 + %3607 = and i32 %3606, 16383, !dbg !178 + %3608 = zext nneg i32 %3607 to i64, !dbg !178 + %3609 = or disjoint i64 %3608, 4611686293372403712, !dbg !178 + %3610 = add i32 %3427, 8224, !dbg !178 + %3611 = lshr exact i32 %3610, 4, !dbg !178 + %3612 = and i32 %3611, 16383, !dbg !178 + %3613 = zext nneg i32 %3612 to i64, !dbg !178 + %3614 = or disjoint i64 %3613, 4611686293338849280, !dbg !178 + %3615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 0, !dbg !178 + %3616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 1, !dbg !178 + %3617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 2, !dbg !178 + %3618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 3, !dbg !178 + %3619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 4, !dbg !178 + %3620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 5, !dbg !178 + %3621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 6, !dbg !178 + %3622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 7, !dbg !178 + %3623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 8, !dbg !178 + %3624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 9, !dbg !178 + %3625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 10, !dbg !178 + %3626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 11, !dbg !178 + %3627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 12, !dbg !178 + %3628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 13, !dbg !178 + %3629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 14, !dbg !178 + %3630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 15, !dbg !178 + %3631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 16, !dbg !178 + %3632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 17, !dbg !178 + %3633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 18, !dbg !178 + %3634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 19, !dbg !178 + %3635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 20, !dbg !178 + %3636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 21, !dbg !178 + %3637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 22, !dbg !178 + %3638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 23, !dbg !178 + %3639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 24, !dbg !178 + %3640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 25, !dbg !178 + %3641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 26, !dbg !178 + %3642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 27, !dbg !178 + %3643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 28, !dbg !178 + %3644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 29, !dbg !178 + %3645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 30, !dbg !178 + %3646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3604, 31, !dbg !178 + %3647 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3615, float %3616, float %3617, float %3618, float %3619, float %3620, float %3621, float %3622, float %3623, float %3624, float %3625, float %3626, float %3627, float %3628, float %3629, float %3630, float %3631, float %3632, float %3633, float %3634, float %3635, float %3636, float %3637, float %3638, float %3639, float %3640, float %3641, float %3642, float %3643, float %3644, float %3645, float %3646, i64 %3609, i64 %3614, i1 true) #3, !dbg !178 + %3648 = add i32 %2980, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3649 = lshr exact i32 %3648, 4, !dbg !178 + %3650 = and i32 %3649, 16383, !dbg !178 + %3651 = zext nneg i32 %3650 to i64, !dbg !178 + %3652 = or disjoint i64 %3651, 4611686293372403712, !dbg !178 + %3653 = add i32 %3427, 8256, !dbg !178 + %3654 = lshr exact i32 %3653, 4, !dbg !178 + %3655 = and i32 %3654, 16383, !dbg !178 + %3656 = zext nneg i32 %3655 to i64, !dbg !178 + %3657 = or disjoint i64 %3656, 4611686293338849280, !dbg !178 + %3658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 0, !dbg !178 + %3659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 1, !dbg !178 + %3660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 2, !dbg !178 + %3661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 3, !dbg !178 + %3662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 4, !dbg !178 + %3663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 5, !dbg !178 + %3664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 6, !dbg !178 + %3665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 7, !dbg !178 + %3666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 8, !dbg !178 + %3667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 9, !dbg !178 + %3668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 10, !dbg !178 + %3669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 11, !dbg !178 + %3670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 12, !dbg !178 + %3671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 13, !dbg !178 + %3672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 14, !dbg !178 + %3673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 15, !dbg !178 + %3674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 16, !dbg !178 + %3675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 17, !dbg !178 + %3676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 18, !dbg !178 + %3677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 19, !dbg !178 + %3678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 20, !dbg !178 + %3679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 21, !dbg !178 + %3680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 22, !dbg !178 + %3681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 23, !dbg !178 + %3682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 24, !dbg !178 + %3683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 25, !dbg !178 + %3684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 26, !dbg !178 + %3685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 27, !dbg !178 + %3686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 28, !dbg !178 + %3687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 29, !dbg !178 + %3688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 30, !dbg !178 + %3689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3647, 31, !dbg !178 + %3690 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3658, float %3659, float %3660, float %3661, float %3662, float %3663, float %3664, float %3665, float %3666, float %3667, float %3668, float %3669, float %3670, float %3671, float %3672, float %3673, float %3674, float %3675, float %3676, float %3677, float %3678, float %3679, float %3680, float %3681, float %3682, float %3683, float %3684, float %3685, float %3686, float %3687, float %3688, float %3689, i64 %3652, i64 %3657, i1 true) #3, !dbg !178 + %3691 = add i32 %3024, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072) to i32), !dbg !178 + %3692 = lshr exact i32 %3691, 4, !dbg !178 + %3693 = and i32 %3692, 16383, !dbg !178 + %3694 = zext nneg i32 %3693 to i64, !dbg !178 + %3695 = or disjoint i64 %3694, 4611686293372403712, !dbg !178 + %3696 = add i32 %3427, 8288, !dbg !178 + %3697 = lshr exact i32 %3696, 4, !dbg !178 + %3698 = and i32 %3697, 16383, !dbg !178 + %3699 = zext nneg i32 %3698 to i64, !dbg !178 + %3700 = or disjoint i64 %3699, 4611686293338849280, !dbg !178 + %3701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 0, !dbg !178 + %3702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 1, !dbg !178 + %3703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 2, !dbg !178 + %3704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 3, !dbg !178 + %3705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 4, !dbg !178 + %3706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 5, !dbg !178 + %3707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 6, !dbg !178 + %3708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 7, !dbg !178 + %3709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 8, !dbg !178 + %3710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 9, !dbg !178 + %3711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 10, !dbg !178 + %3712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 11, !dbg !178 + %3713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 12, !dbg !178 + %3714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 13, !dbg !178 + %3715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 14, !dbg !178 + %3716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 15, !dbg !178 + %3717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 16, !dbg !178 + %3718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 17, !dbg !178 + %3719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 18, !dbg !178 + %3720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 19, !dbg !178 + %3721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 20, !dbg !178 + %3722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 21, !dbg !178 + %3723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 22, !dbg !178 + %3724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 23, !dbg !178 + %3725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 24, !dbg !178 + %3726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 25, !dbg !178 + %3727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 26, !dbg !178 + %3728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 27, !dbg !178 + %3729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 28, !dbg !178 + %3730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 29, !dbg !178 + %3731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 30, !dbg !178 + %3732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3690, 31, !dbg !178 + %3733 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %3701, float %3702, float %3703, float %3704, float %3705, float %3706, float %3707, float %3708, float %3709, float %3710, float %3711, float %3712, float %3713, float %3714, float %3715, float %3716, float %3717, float %3718, float %3719, float %3720, float %3721, float %3722, float %3723, float %3724, float %3725, float %3726, float %3727, float %3728, float %3729, float %3730, float %3731, float %3732, i64 %3695, i64 %3700, i1 true) #3, !dbg !178 + %3734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 0, !dbg !178 + %3735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 1, !dbg !178 + %3736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 2, !dbg !178 + %3737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 3, !dbg !178 + %3738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 4, !dbg !178 + %3739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 5, !dbg !178 + %3740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 6, !dbg !178 + %3741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 7, !dbg !178 + %3742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 8, !dbg !178 + %3743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 9, !dbg !178 + %3744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 10, !dbg !178 + %3745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 11, !dbg !178 + %3746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 12, !dbg !178 + %3747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 13, !dbg !178 + %3748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 14, !dbg !178 + %3749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 15, !dbg !178 + %3750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 16, !dbg !178 + %3751 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 17, !dbg !178 + %3752 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 18, !dbg !178 + %3753 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 19, !dbg !178 + %3754 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 20, !dbg !178 + %3755 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 21, !dbg !178 + %3756 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 22, !dbg !178 + %3757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 23, !dbg !178 + %3758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 24, !dbg !178 + %3759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 25, !dbg !178 + %3760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 26, !dbg !178 + %3761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 27, !dbg !178 + %3762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 28, !dbg !178 + %3763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 29, !dbg !178 + %3764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 30, !dbg !178 + %3765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3733, 31, !dbg !178 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !178 + %3766 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %3734, float %3735, float %3736, float %3737, float %3738, float %3739, float %3740, float %3741, float %3742, float %3743, float %3744, float %3745, float %3746, float %3747, float %3748, float %3749, float %3750, float %3751, float %3752, float %3753, float %3754, float %3755, float %3756, float %3757, float %3758, float %3759, float %3760, float %3761, float %3762, float %3763, float %3764, float %3765, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 131072), i32 0, i32 0, ptr addrspace(3) %3421, i32 0, i32 0) #3, !dbg !178 + %3767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 0, !dbg !178 + %3768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 1, !dbg !178 + %3769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 2, !dbg !178 + %3770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 3, !dbg !178 + %3771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 4, !dbg !178 + %3772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 5, !dbg !178 + %3773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 6, !dbg !178 + %3774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 7, !dbg !178 + %3775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 8, !dbg !178 + %3776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 9, !dbg !178 + %3777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 10, !dbg !178 + %3778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 11, !dbg !178 + %3779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 12, !dbg !178 + %3780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 13, !dbg !178 + %3781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 14, !dbg !178 + %3782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 15, !dbg !178 + %3783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 16, !dbg !178 + %3784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 17, !dbg !178 + %3785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 18, !dbg !178 + %3786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 19, !dbg !178 + %3787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 20, !dbg !178 + %3788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 21, !dbg !178 + %3789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 22, !dbg !178 + %3790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 23, !dbg !178 + %3791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 24, !dbg !178 + %3792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 25, !dbg !178 + %3793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 26, !dbg !178 + %3794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 27, !dbg !178 + %3795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 28, !dbg !178 + %3796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 29, !dbg !178 + %3797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 30, !dbg !178 + %3798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %3766, 31, !dbg !178 + %3799 = fsub float %3767, %355, !dbg !179 + %3800 = fsub float %3768, %355, !dbg !179 + %3801 = fsub float %3769, %357, !dbg !179 + %3802 = fsub float %3770, %357, !dbg !179 + %3803 = fsub float %3771, %355, !dbg !179 + %3804 = fsub float %3772, %355, !dbg !179 + %3805 = fsub float %3773, %357, !dbg !179 + %3806 = fsub float %3774, %357, !dbg !179 + %3807 = fsub float %3775, %355, !dbg !179 + %3808 = fsub float %3776, %355, !dbg !179 + %3809 = fsub float %3777, %357, !dbg !179 + %3810 = fsub float %3778, %357, !dbg !179 + %3811 = fsub float %3779, %355, !dbg !179 + %3812 = fsub float %3780, %355, !dbg !179 + %3813 = fsub float %3781, %357, !dbg !179 + %3814 = fsub float %3782, %357, !dbg !179 + %3815 = fsub float %3783, %355, !dbg !179 + %3816 = fsub float %3784, %355, !dbg !179 + %3817 = fsub float %3785, %357, !dbg !179 + %3818 = fsub float %3786, %357, !dbg !179 + %3819 = fsub float %3787, %355, !dbg !179 + %3820 = fsub float %3788, %355, !dbg !179 + %3821 = fsub float %3789, %357, !dbg !179 + %3822 = fsub float %3790, %357, !dbg !179 + %3823 = fsub float %3791, %355, !dbg !179 + %3824 = fsub float %3792, %355, !dbg !179 + %3825 = fsub float %3793, %357, !dbg !179 + %3826 = fsub float %3794, %357, !dbg !179 + %3827 = fsub float %3795, %355, !dbg !179 + %3828 = fsub float %3796, %355, !dbg !179 + %3829 = fsub float %3797, %357, !dbg !179 + %3830 = fsub float %3798, %357, !dbg !179 + %3831 = fmul float %.0.i1419, %3799, !dbg !180 + %3832 = fmul float %.0.i1422, %3800, !dbg !180 + %3833 = fmul float %.0.i1425, %3801, !dbg !180 + %3834 = fmul float %.0.i1428, %3802, !dbg !180 + %3835 = fmul float %.0.i1431, %3803, !dbg !180 + %3836 = fmul float %.0.i1434, %3804, !dbg !180 + %3837 = fmul float %.0.i1437, %3805, !dbg !180 + %3838 = fmul float %.0.i1440, %3806, !dbg !180 + %3839 = fmul float %.0.i1443, %3807, !dbg !180 + %3840 = fmul float %.0.i1446, %3808, !dbg !180 + %3841 = fmul float %.0.i1449, %3809, !dbg !180 + %3842 = fmul float %.0.i1452, %3810, !dbg !180 + %3843 = fmul float %.0.i1455, %3811, !dbg !180 + %3844 = fmul float %.0.i1458, %3812, !dbg !180 + %3845 = fmul float %.0.i1461, %3813, !dbg !180 + %3846 = fmul float %.0.i1464, %3814, !dbg !180 + %3847 = fmul float %.0.i1467, %3815, !dbg !180 + %3848 = fmul float %.0.i1470, %3816, !dbg !180 + %3849 = fmul float %.0.i1473, %3817, !dbg !180 + %3850 = fmul float %.0.i1476, %3818, !dbg !180 + %3851 = fmul float %.0.i1479, %3819, !dbg !180 + %3852 = fmul float %.0.i1482, %3820, !dbg !180 + %3853 = fmul float %.0.i1485, %3821, !dbg !180 + %3854 = fmul float %.0.i1488, %3822, !dbg !180 + %3855 = fmul float %.0.i1491, %3823, !dbg !180 + %3856 = fmul float %.0.i1494, %3824, !dbg !180 + %3857 = fmul float %.0.i1497, %3825, !dbg !180 + %3858 = fmul float %.0.i1500, %3826, !dbg !180 + %3859 = fmul float %.0.i1503, %3827, !dbg !180 + %3860 = fmul float %.0.i1506, %3828, !dbg !180 + %3861 = fmul float %.0.i1509, %3829, !dbg !180 + %3862 = fmul float %.0.i1512, %3830, !dbg !180 + %3863 = fptrunc float %3831 to bfloat, !dbg !181 + %3864 = select i1 %2713, bfloat %3863, bfloat 0xR0000, !dbg !182 + %3865 = fptrunc float %3832 to bfloat, !dbg !181 + %3866 = select i1 %2715, bfloat %3865, bfloat 0xR0000, !dbg !182 + %3867 = fptrunc float %3833 to bfloat, !dbg !181 + %3868 = select i1 %2713, bfloat %3867, bfloat 0xR0000, !dbg !182 + %3869 = fptrunc float %3834 to bfloat, !dbg !181 + %3870 = select i1 %2715, bfloat %3869, bfloat 0xR0000, !dbg !182 + %3871 = fptrunc float %3835 to bfloat, !dbg !181 + %3872 = select i1 %2717, bfloat %3871, bfloat 0xR0000, !dbg !182 + %3873 = fptrunc float %3836 to bfloat, !dbg !181 + %3874 = select i1 %2719, bfloat %3873, bfloat 0xR0000, !dbg !182 + %3875 = fptrunc float %3837 to bfloat, !dbg !181 + %3876 = select i1 %2717, bfloat %3875, bfloat 0xR0000, !dbg !182 + %3877 = fptrunc float %3838 to bfloat, !dbg !181 + %3878 = select i1 %2719, bfloat %3877, bfloat 0xR0000, !dbg !182 + %3879 = fptrunc float %3839 to bfloat, !dbg !181 + %3880 = select i1 %2721, bfloat %3879, bfloat 0xR0000, !dbg !182 + %3881 = fptrunc float %3840 to bfloat, !dbg !181 + %3882 = select i1 %2723, bfloat %3881, bfloat 0xR0000, !dbg !182 + %3883 = fptrunc float %3841 to bfloat, !dbg !181 + %3884 = select i1 %2721, bfloat %3883, bfloat 0xR0000, !dbg !182 + %3885 = fptrunc float %3842 to bfloat, !dbg !181 + %3886 = select i1 %2723, bfloat %3885, bfloat 0xR0000, !dbg !182 + %3887 = fptrunc float %3843 to bfloat, !dbg !181 + %3888 = select i1 %2725, bfloat %3887, bfloat 0xR0000, !dbg !182 + %3889 = fptrunc float %3844 to bfloat, !dbg !181 + %3890 = select i1 %2727, bfloat %3889, bfloat 0xR0000, !dbg !182 + %3891 = fptrunc float %3845 to bfloat, !dbg !181 + %3892 = select i1 %2725, bfloat %3891, bfloat 0xR0000, !dbg !182 + %3893 = fptrunc float %3846 to bfloat, !dbg !181 + %3894 = select i1 %2727, bfloat %3893, bfloat 0xR0000, !dbg !182 + %3895 = fptrunc float %3847 to bfloat, !dbg !181 + %3896 = select i1 %2729, bfloat %3895, bfloat 0xR0000, !dbg !182 + %3897 = fptrunc float %3848 to bfloat, !dbg !181 + %3898 = select i1 %2731, bfloat %3897, bfloat 0xR0000, !dbg !182 + %3899 = fptrunc float %3849 to bfloat, !dbg !181 + %3900 = select i1 %2729, bfloat %3899, bfloat 0xR0000, !dbg !182 + %3901 = fptrunc float %3850 to bfloat, !dbg !181 + %3902 = select i1 %2731, bfloat %3901, bfloat 0xR0000, !dbg !182 + %3903 = fptrunc float %3851 to bfloat, !dbg !181 + %3904 = select i1 %2733, bfloat %3903, bfloat 0xR0000, !dbg !182 + %3905 = fptrunc float %3852 to bfloat, !dbg !181 + %3906 = select i1 %2735, bfloat %3905, bfloat 0xR0000, !dbg !182 + %3907 = fptrunc float %3853 to bfloat, !dbg !181 + %3908 = select i1 %2733, bfloat %3907, bfloat 0xR0000, !dbg !182 + %3909 = fptrunc float %3854 to bfloat, !dbg !181 + %3910 = select i1 %2735, bfloat %3909, bfloat 0xR0000, !dbg !182 + %3911 = fptrunc float %3855 to bfloat, !dbg !181 + %3912 = select i1 %2737, bfloat %3911, bfloat 0xR0000, !dbg !182 + %3913 = fptrunc float %3856 to bfloat, !dbg !181 + %3914 = select i1 %2739, bfloat %3913, bfloat 0xR0000, !dbg !182 + %3915 = fptrunc float %3857 to bfloat, !dbg !181 + %3916 = select i1 %2737, bfloat %3915, bfloat 0xR0000, !dbg !182 + %3917 = fptrunc float %3858 to bfloat, !dbg !181 + %3918 = select i1 %2739, bfloat %3917, bfloat 0xR0000, !dbg !182 + %3919 = fptrunc float %3859 to bfloat, !dbg !181 + %3920 = select i1 %2741, bfloat %3919, bfloat 0xR0000, !dbg !182 + %3921 = fptrunc float %3860 to bfloat, !dbg !181 + %3922 = select i1 %2743, bfloat %3921, bfloat 0xR0000, !dbg !182 + %3923 = fptrunc float %3861 to bfloat, !dbg !181 + %3924 = select i1 %2741, bfloat %3923, bfloat 0xR0000, !dbg !182 + %3925 = fptrunc float %3862 to bfloat, !dbg !181 + %3926 = select i1 %2743, bfloat %3925, bfloat 0xR0000, !dbg !182 + %3927 = insertelement <2 x bfloat> poison, bfloat %3864, i64 0, !dbg !183 + %3928 = insertelement <2 x bfloat> %3927, bfloat %3866, i64 1, !dbg !183 + %3929 = bitcast <2 x bfloat> %3928 to i32, !dbg !183 + %3930 = insertelement <2 x bfloat> poison, bfloat %3868, i64 0, !dbg !183 + %3931 = insertelement <2 x bfloat> %3930, bfloat %3870, i64 1, !dbg !183 + %3932 = bitcast <2 x bfloat> %3931 to i32, !dbg !183 + %3933 = insertelement <2 x bfloat> poison, bfloat %3872, i64 0, !dbg !183 + %3934 = insertelement <2 x bfloat> %3933, bfloat %3874, i64 1, !dbg !183 + %3935 = bitcast <2 x bfloat> %3934 to i32, !dbg !183 + %3936 = insertelement <2 x bfloat> poison, bfloat %3876, i64 0, !dbg !183 + %3937 = insertelement <2 x bfloat> %3936, bfloat %3878, i64 1, !dbg !183 + %3938 = bitcast <2 x bfloat> %3937 to i32, !dbg !183 + %3939 = insertelement <2 x bfloat> poison, bfloat %3880, i64 0, !dbg !183 + %3940 = insertelement <2 x bfloat> %3939, bfloat %3882, i64 1, !dbg !183 + %3941 = bitcast <2 x bfloat> %3940 to i32, !dbg !183 + %3942 = insertelement <2 x bfloat> poison, bfloat %3884, i64 0, !dbg !183 + %3943 = insertelement <2 x bfloat> %3942, bfloat %3886, i64 1, !dbg !183 + %3944 = bitcast <2 x bfloat> %3943 to i32, !dbg !183 + %3945 = insertelement <2 x bfloat> poison, bfloat %3888, i64 0, !dbg !183 + %3946 = insertelement <2 x bfloat> %3945, bfloat %3890, i64 1, !dbg !183 + %3947 = bitcast <2 x bfloat> %3946 to i32, !dbg !183 + %3948 = insertelement <2 x bfloat> poison, bfloat %3892, i64 0, !dbg !183 + %3949 = insertelement <2 x bfloat> %3948, bfloat %3894, i64 1, !dbg !183 + %3950 = bitcast <2 x bfloat> %3949 to i32, !dbg !183 + %3951 = insertelement <2 x bfloat> poison, bfloat %3896, i64 0, !dbg !183 + %3952 = insertelement <2 x bfloat> %3951, bfloat %3898, i64 1, !dbg !183 + %3953 = bitcast <2 x bfloat> %3952 to i32, !dbg !183 + %3954 = insertelement <2 x bfloat> poison, bfloat %3900, i64 0, !dbg !183 + %3955 = insertelement <2 x bfloat> %3954, bfloat %3902, i64 1, !dbg !183 + %3956 = bitcast <2 x bfloat> %3955 to i32, !dbg !183 + %3957 = insertelement <2 x bfloat> poison, bfloat %3904, i64 0, !dbg !183 + %3958 = insertelement <2 x bfloat> %3957, bfloat %3906, i64 1, !dbg !183 + %3959 = bitcast <2 x bfloat> %3958 to i32, !dbg !183 + %3960 = insertelement <2 x bfloat> poison, bfloat %3908, i64 0, !dbg !183 + %3961 = insertelement <2 x bfloat> %3960, bfloat %3910, i64 1, !dbg !183 + %3962 = bitcast <2 x bfloat> %3961 to i32, !dbg !183 + %3963 = insertelement <2 x bfloat> poison, bfloat %3912, i64 0, !dbg !183 + %3964 = insertelement <2 x bfloat> %3963, bfloat %3914, i64 1, !dbg !183 + %3965 = bitcast <2 x bfloat> %3964 to i32, !dbg !183 + %3966 = insertelement <2 x bfloat> poison, bfloat %3916, i64 0, !dbg !183 + %3967 = insertelement <2 x bfloat> %3966, bfloat %3918, i64 1, !dbg !183 + %3968 = bitcast <2 x bfloat> %3967 to i32, !dbg !183 + %3969 = insertelement <2 x bfloat> poison, bfloat %3920, i64 0, !dbg !183 + %3970 = insertelement <2 x bfloat> %3969, bfloat %3922, i64 1, !dbg !183 + %3971 = bitcast <2 x bfloat> %3970 to i32, !dbg !183 + %3972 = insertelement <2 x bfloat> poison, bfloat %3924, i64 0, !dbg !183 + %3973 = insertelement <2 x bfloat> %3972, bfloat %3926, i64 1, !dbg !183 + %3974 = bitcast <2 x bfloat> %3973 to i32, !dbg !183 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !183 + %3975 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn, float %.pn2530, float %.pn2531, float %.pn2532, float %.pn2533, float %.pn2534, float %.pn2535, float %.pn2536, float %.pn2537, float %.pn2538, float %.pn2539, float %.pn2540, float %.pn2541, float %.pn2542, float %.pn2543, float %.pn2544, float %.pn2545, float %.pn2546, float %.pn2547, float %.pn2548, float %.pn2549, float %.pn2550, float %.pn2551, float %.pn2552, float %.pn2553, float %.pn2554, float %.pn2555, float %.pn2556, float %.pn2557, float %.pn2558, float %.pn2559, float %.pn2560, float %.pn2561, float %.pn2562, float %.pn2563, float %.pn2564, float %.pn2565, float %.pn2566, float %.pn2567, float %.pn2568, float %.pn2569, float %.pn2570, float %.pn2571, float %.pn2572, float %.pn2573, float %.pn2574, float %.pn2575, float %.pn2576, float %.pn2577, float %.pn2578, float %.pn2579, float %.pn2580, float %.pn2581, float %.pn2582, float %.pn2583, float %.pn2584, float %.pn2585, float %.pn2586, float %.pn2587, float %.pn2588, float %.pn2589, float %.pn2590, float %.pn2591, float %.pn2592, i32 %3929, i32 %3932, i32 %3935, i32 %3938, i64 %2758, i1 true) #3, !dbg !183 + %3976 = add i32 %2754, 2048, !dbg !183 + %3977 = lshr exact i32 %3976, 4, !dbg !183 + %3978 = and i32 %3977, 16383, !dbg !183 + %3979 = zext nneg i32 %3978 to i64, !dbg !183 + %3980 = or disjoint i64 %3979, 4611686293338849280, !dbg !183 + %3981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 0, !dbg !183 + %3982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 1, !dbg !183 + %3983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 2, !dbg !183 + %3984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 3, !dbg !183 + %3985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 4, !dbg !183 + %3986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 5, !dbg !183 + %3987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 6, !dbg !183 + %3988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 7, !dbg !183 + %3989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 8, !dbg !183 + %3990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 9, !dbg !183 + %3991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 10, !dbg !183 + %3992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 11, !dbg !183 + %3993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 12, !dbg !183 + %3994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 13, !dbg !183 + %3995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 14, !dbg !183 + %3996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 15, !dbg !183 + %3997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 16, !dbg !183 + %3998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 17, !dbg !183 + %3999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 18, !dbg !183 + %4000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 19, !dbg !183 + %4001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 20, !dbg !183 + %4002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 21, !dbg !183 + %4003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 22, !dbg !183 + %4004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 23, !dbg !183 + %4005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 24, !dbg !183 + %4006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 25, !dbg !183 + %4007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 26, !dbg !183 + %4008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 27, !dbg !183 + %4009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 28, !dbg !183 + %4010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 29, !dbg !183 + %4011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 30, !dbg !183 + %4012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 31, !dbg !183 + %4013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 32, !dbg !183 + %4014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 33, !dbg !183 + %4015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 34, !dbg !183 + %4016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 35, !dbg !183 + %4017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 36, !dbg !183 + %4018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 37, !dbg !183 + %4019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 38, !dbg !183 + %4020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 39, !dbg !183 + %4021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 40, !dbg !183 + %4022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 41, !dbg !183 + %4023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 42, !dbg !183 + %4024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 43, !dbg !183 + %4025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 44, !dbg !183 + %4026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 45, !dbg !183 + %4027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 46, !dbg !183 + %4028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 47, !dbg !183 + %4029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 48, !dbg !183 + %4030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 49, !dbg !183 + %4031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 50, !dbg !183 + %4032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 51, !dbg !183 + %4033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 52, !dbg !183 + %4034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 53, !dbg !183 + %4035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 54, !dbg !183 + %4036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 55, !dbg !183 + %4037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 56, !dbg !183 + %4038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 57, !dbg !183 + %4039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 58, !dbg !183 + %4040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 59, !dbg !183 + %4041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 60, !dbg !183 + %4042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 61, !dbg !183 + %4043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 62, !dbg !183 + %4044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %3975, 63, !dbg !183 + %4045 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %3981, float %3982, float %3983, float %3984, float %3985, float %3986, float %3987, float %3988, float %3989, float %3990, float %3991, float %3992, float %3993, float %3994, float %3995, float %3996, float %3997, float %3998, float %3999, float %4000, float %4001, float %4002, float %4003, float %4004, float %4005, float %4006, float %4007, float %4008, float %4009, float %4010, float %4011, float %4012, float %4013, float %4014, float %4015, float %4016, float %4017, float %4018, float %4019, float %4020, float %4021, float %4022, float %4023, float %4024, float %4025, float %4026, float %4027, float %4028, float %4029, float %4030, float %4031, float %4032, float %4033, float %4034, float %4035, float %4036, float %4037, float %4038, float %4039, float %4040, float %4041, float %4042, float %4043, float %4044, i32 %3941, i32 %3944, i32 %3947, i32 %3950, i64 %3980, i1 true) #3, !dbg !183 + %4046 = add i32 %2754, 4096, !dbg !183 + %4047 = lshr exact i32 %4046, 4, !dbg !183 + %4048 = and i32 %4047, 16383, !dbg !183 + %4049 = zext nneg i32 %4048 to i64, !dbg !183 + %4050 = or disjoint i64 %4049, 4611686293338849280, !dbg !183 + %4051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 0, !dbg !183 + %4052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 1, !dbg !183 + %4053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 2, !dbg !183 + %4054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 3, !dbg !183 + %4055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 4, !dbg !183 + %4056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 5, !dbg !183 + %4057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 6, !dbg !183 + %4058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 7, !dbg !183 + %4059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 8, !dbg !183 + %4060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 9, !dbg !183 + %4061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 10, !dbg !183 + %4062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 11, !dbg !183 + %4063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 12, !dbg !183 + %4064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 13, !dbg !183 + %4065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 14, !dbg !183 + %4066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 15, !dbg !183 + %4067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 16, !dbg !183 + %4068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 17, !dbg !183 + %4069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 18, !dbg !183 + %4070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 19, !dbg !183 + %4071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 20, !dbg !183 + %4072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 21, !dbg !183 + %4073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 22, !dbg !183 + %4074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 23, !dbg !183 + %4075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 24, !dbg !183 + %4076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 25, !dbg !183 + %4077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 26, !dbg !183 + %4078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 27, !dbg !183 + %4079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 28, !dbg !183 + %4080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 29, !dbg !183 + %4081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 30, !dbg !183 + %4082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 31, !dbg !183 + %4083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 32, !dbg !183 + %4084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 33, !dbg !183 + %4085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 34, !dbg !183 + %4086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 35, !dbg !183 + %4087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 36, !dbg !183 + %4088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 37, !dbg !183 + %4089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 38, !dbg !183 + %4090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 39, !dbg !183 + %4091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 40, !dbg !183 + %4092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 41, !dbg !183 + %4093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 42, !dbg !183 + %4094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 43, !dbg !183 + %4095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 44, !dbg !183 + %4096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 45, !dbg !183 + %4097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 46, !dbg !183 + %4098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 47, !dbg !183 + %4099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 48, !dbg !183 + %4100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 49, !dbg !183 + %4101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 50, !dbg !183 + %4102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 51, !dbg !183 + %4103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 52, !dbg !183 + %4104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 53, !dbg !183 + %4105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 54, !dbg !183 + %4106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 55, !dbg !183 + %4107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 56, !dbg !183 + %4108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 57, !dbg !183 + %4109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 58, !dbg !183 + %4110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 59, !dbg !183 + %4111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 60, !dbg !183 + %4112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 61, !dbg !183 + %4113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 62, !dbg !183 + %4114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4045, 63, !dbg !183 + %4115 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %4051, float %4052, float %4053, float %4054, float %4055, float %4056, float %4057, float %4058, float %4059, float %4060, float %4061, float %4062, float %4063, float %4064, float %4065, float %4066, float %4067, float %4068, float %4069, float %4070, float %4071, float %4072, float %4073, float %4074, float %4075, float %4076, float %4077, float %4078, float %4079, float %4080, float %4081, float %4082, float %4083, float %4084, float %4085, float %4086, float %4087, float %4088, float %4089, float %4090, float %4091, float %4092, float %4093, float %4094, float %4095, float %4096, float %4097, float %4098, float %4099, float %4100, float %4101, float %4102, float %4103, float %4104, float %4105, float %4106, float %4107, float %4108, float %4109, float %4110, float %4111, float %4112, float %4113, float %4114, i32 %3953, i32 %3956, i32 %3959, i32 %3962, i64 %4050, i1 true) #3, !dbg !183 + %4116 = add i32 %2754, 6144, !dbg !183 + %4117 = lshr exact i32 %4116, 4, !dbg !183 + %4118 = and i32 %4117, 16383, !dbg !183 + %4119 = zext nneg i32 %4118 to i64, !dbg !183 + %4120 = or disjoint i64 %4119, 4611686293338849280, !dbg !183 + %4121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 0, !dbg !183 + %4122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 1, !dbg !183 + %4123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 2, !dbg !183 + %4124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 3, !dbg !183 + %4125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 4, !dbg !183 + %4126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 5, !dbg !183 + %4127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 6, !dbg !183 + %4128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 7, !dbg !183 + %4129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 8, !dbg !183 + %4130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 9, !dbg !183 + %4131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 10, !dbg !183 + %4132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 11, !dbg !183 + %4133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 12, !dbg !183 + %4134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 13, !dbg !183 + %4135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 14, !dbg !183 + %4136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 15, !dbg !183 + %4137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 16, !dbg !183 + %4138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 17, !dbg !183 + %4139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 18, !dbg !183 + %4140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 19, !dbg !183 + %4141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 20, !dbg !183 + %4142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 21, !dbg !183 + %4143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 22, !dbg !183 + %4144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 23, !dbg !183 + %4145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 24, !dbg !183 + %4146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 25, !dbg !183 + %4147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 26, !dbg !183 + %4148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 27, !dbg !183 + %4149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 28, !dbg !183 + %4150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 29, !dbg !183 + %4151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 30, !dbg !183 + %4152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 31, !dbg !183 + %4153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 32, !dbg !183 + %4154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 33, !dbg !183 + %4155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 34, !dbg !183 + %4156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 35, !dbg !183 + %4157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 36, !dbg !183 + %4158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 37, !dbg !183 + %4159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 38, !dbg !183 + %4160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 39, !dbg !183 + %4161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 40, !dbg !183 + %4162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 41, !dbg !183 + %4163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 42, !dbg !183 + %4164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 43, !dbg !183 + %4165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 44, !dbg !183 + %4166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 45, !dbg !183 + %4167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 46, !dbg !183 + %4168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 47, !dbg !183 + %4169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 48, !dbg !183 + %4170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 49, !dbg !183 + %4171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 50, !dbg !183 + %4172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 51, !dbg !183 + %4173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 52, !dbg !183 + %4174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 53, !dbg !183 + %4175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 54, !dbg !183 + %4176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 55, !dbg !183 + %4177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 56, !dbg !183 + %4178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 57, !dbg !183 + %4179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 58, !dbg !183 + %4180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 59, !dbg !183 + %4181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 60, !dbg !183 + %4182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 61, !dbg !183 + %4183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 62, !dbg !183 + %4184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4115, 63, !dbg !183 + %4185 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %4121, float %4122, float %4123, float %4124, float %4125, float %4126, float %4127, float %4128, float %4129, float %4130, float %4131, float %4132, float %4133, float %4134, float %4135, float %4136, float %4137, float %4138, float %4139, float %4140, float %4141, float %4142, float %4143, float %4144, float %4145, float %4146, float %4147, float %4148, float %4149, float %4150, float %4151, float %4152, float %4153, float %4154, float %4155, float %4156, float %4157, float %4158, float %4159, float %4160, float %4161, float %4162, float %4163, float %4164, float %4165, float %4166, float %4167, float %4168, float %4169, float %4170, float %4171, float %4172, float %4173, float %4174, float %4175, float %4176, float %4177, float %4178, float %4179, float %4180, float %4181, float %4182, float %4183, float %4184, i32 %3965, i32 %3968, i32 %3971, i32 %3974, i64 %4120, i1 true) #3, !dbg !183 + %4186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 0, !dbg !183 + %4187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 1, !dbg !183 + %4188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 2, !dbg !183 + %4189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 3, !dbg !183 + %4190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 4, !dbg !183 + %4191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 5, !dbg !183 + %4192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 6, !dbg !183 + %4193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 7, !dbg !183 + %4194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 8, !dbg !183 + %4195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 9, !dbg !183 + %4196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 10, !dbg !183 + %4197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 11, !dbg !183 + %4198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 12, !dbg !183 + %4199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 13, !dbg !183 + %4200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 14, !dbg !183 + %4201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 15, !dbg !183 + %4202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 16, !dbg !183 + %4203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 17, !dbg !183 + %4204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 18, !dbg !183 + %4205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 19, !dbg !183 + %4206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 20, !dbg !183 + %4207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 21, !dbg !183 + %4208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 22, !dbg !183 + %4209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 23, !dbg !183 + %4210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 24, !dbg !183 + %4211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 25, !dbg !183 + %4212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 26, !dbg !183 + %4213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 27, !dbg !183 + %4214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 28, !dbg !183 + %4215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 29, !dbg !183 + %4216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 30, !dbg !183 + %4217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 31, !dbg !183 + %4218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 32, !dbg !183 + %4219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 33, !dbg !183 + %4220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 34, !dbg !183 + %4221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 35, !dbg !183 + %4222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 36, !dbg !183 + %4223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 37, !dbg !183 + %4224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 38, !dbg !183 + %4225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 39, !dbg !183 + %4226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 40, !dbg !183 + %4227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 41, !dbg !183 + %4228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 42, !dbg !183 + %4229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 43, !dbg !183 + %4230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 44, !dbg !183 + %4231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 45, !dbg !183 + %4232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 46, !dbg !183 + %4233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 47, !dbg !183 + %4234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 48, !dbg !183 + %4235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 49, !dbg !183 + %4236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 50, !dbg !183 + %4237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 51, !dbg !183 + %4238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 52, !dbg !183 + %4239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 53, !dbg !183 + %4240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 54, !dbg !183 + %4241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 55, !dbg !183 + %4242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 56, !dbg !183 + %4243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 57, !dbg !183 + %4244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 58, !dbg !183 + %4245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 59, !dbg !183 + %4246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 60, !dbg !183 + %4247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 61, !dbg !183 + %4248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 62, !dbg !183 + %4249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4185, 63, !dbg !183 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !183 + %4250 = insertelement <16 x i32> poison, i32 %2702, i64 0, !dbg !171 + %4251 = shufflevector <16 x i32> %4250, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !171 + %4252 = add <16 x i32> %4251, %2706, !dbg !171 + %4253 = add nuw nsw i32 %2705, 1, !dbg !166 + %4254 = lshr i32 %4253, 1, !dbg !184 + %4255 = zext nneg i32 %4254 to i64, !dbg !185 + %4256 = getelementptr i32, ptr addrspace(1) %2488, i64 %4255, !dbg !185 + %4257 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !186 + %4258 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %4256, i64 %4257, i1 %2708) #3, !dbg !186 + %4259 = add nuw nsw i32 %4254, 1, !dbg !187 + %4260 = icmp slt i32 %4259, %2492, !dbg !188 + %4261 = getelementptr i8, ptr addrspace(1) %4256, i64 4, !dbg !189 + %4262 = and i1 %2708, %4260, !dbg !166 + %4263 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !190 + %4264 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %4261, i64 %4263, i1 %4262) #3, !dbg !190 + %4265 = and i32 %2705, 1, !dbg !191 + %4266 = sub i32 %4264, %4258, !dbg !192 + %4267 = shl i32 %4266, 7, !dbg !193 + %4268 = add i32 %4267, -64, !dbg !194 + %4269 = xor i32 %4265, 1, !dbg !195 + %4270 = mul nuw nsw i32 %4268, %4269, !dbg !195 + %4271 = shl nuw nsw i32 %4265, 6, !dbg !196 + %4272 = add i32 %4270, %4271, !dbg !197 + %4273 = shl i32 %4272, 10, !dbg !198 + %4274 = sext i32 %4273 to i64, !dbg !169 + %4275 = getelementptr bfloat, ptr addrspace(1) %.pn10771643, i64 %4274, !dbg !169 + %4276 = getelementptr bfloat, ptr addrspace(1) %.pn10611644, i64 %4274, !dbg !169 + %4277 = getelementptr bfloat, ptr addrspace(1) %.pn10451645, i64 %4274, !dbg !169 + %4278 = getelementptr bfloat, ptr addrspace(1) %.pn10291646, i64 %4274, !dbg !169 + %4279 = getelementptr bfloat, ptr addrspace(1) %.pn11491651, i64 %4274, !dbg !170 + %4280 = getelementptr bfloat, ptr addrspace(1) %.pn11331652, i64 %4274, !dbg !170 + %4281 = getelementptr bfloat, ptr addrspace(1) %.pn11171653, i64 %4274, !dbg !170 + %4282 = getelementptr bfloat, ptr addrspace(1) %.pn11011654, i64 %4274, !dbg !170 + %4283 = add i32 %4272, %.pn10851647, !dbg !171 + %4284 = add i32 %4272, %.pn10831648, !dbg !171 + %4285 = add i32 %4272, %.pn10811649, !dbg !171 + %4286 = add i32 %4272, %.pn10791650, !dbg !171 + %4287 = add i32 %2704, 1, !dbg !166 + %4288 = icmp sgt i32 %4287, 2, !dbg !166 + %4289 = select i1 %4288, i32 0, i32 %4287, !dbg !166 + %4290 = icmp slt i32 %4283, %18, !dbg !167 + %4291 = icmp slt i32 %4284, %18, !dbg !167 + %4292 = icmp slt i32 %4285, %18, !dbg !167 + %4293 = icmp slt i32 %4286, %18, !dbg !167 + %4294 = shl i32 %4289, 13, !dbg !168 + %4295 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %4294, !dbg !168 + %4296 = and i1 %2707, %4290, !dbg !166 + %4297 = and i1 %2707, %4291, !dbg !166 + %4298 = and i1 %2707, %4292, !dbg !166 + %4299 = and i1 %2707, %4293, !dbg !166 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !168 + %4300 = getelementptr inbounds nuw i8, ptr addrspace(3) %4295, i32 %429, !dbg !168 + %4301 = select i1 %4296, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4300, ptr addrspace(1) %4275, i32 %4301) #3, !dbg !168 + %4302 = getelementptr inbounds nuw i8, ptr addrspace(3) %4295, i32 %432, !dbg !168 + %4303 = select i1 %4297, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4302, ptr addrspace(1) %4276, i32 %4303) #3, !dbg !168 + %4304 = getelementptr inbounds nuw i8, ptr addrspace(3) %4295, i32 %435, !dbg !168 + %4305 = select i1 %4298, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4304, ptr addrspace(1) %4277, i32 %4305) #3, !dbg !168 + %4306 = getelementptr inbounds nuw i8, ptr addrspace(3) %4295, i32 %438, !dbg !168 + %4307 = select i1 %4299, i32 16, i32 0, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4306, ptr addrspace(1) %4278, i32 %4307) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + %4308 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %4294, !dbg !168 + %4309 = getelementptr inbounds nuw i8, ptr addrspace(3) %4308, i32 %429, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %4309, ptr addrspace(1) %4279, i32 %4301) #3, !dbg !168 + %4310 = getelementptr inbounds nuw i8, ptr addrspace(3) %4308, i32 %432, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4310, ptr addrspace(1) %4280, i32 %4303) #3, !dbg !168 + %4311 = getelementptr inbounds nuw i8, ptr addrspace(3) %4308, i32 %435, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4311, ptr addrspace(1) %4281, i32 %4305) #3, !dbg !168 + %4312 = getelementptr inbounds nuw i8, ptr addrspace(3) %4308, i32 %438, !dbg !168 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %4312, ptr addrspace(1) %4282, i32 %4307) #3, !dbg !168 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !168 + %exitcond2264.not = icmp eq i32 %4253, %smax2263, !dbg !166 + br i1 %exitcond2264.not, label %._crit_edge1673, label %2701, !dbg !166 + +._crit_edge1673: ; preds = %__nv_exp2f.exit1513, %._crit_edge + %4313 = phi float [ %2561, %._crit_edge ], [ %4186, %__nv_exp2f.exit1513 ], !dbg !91 + %4314 = phi float [ %2562, %._crit_edge ], [ %4187, %__nv_exp2f.exit1513 ], !dbg !91 + %4315 = phi float [ %2563, %._crit_edge ], [ %4188, %__nv_exp2f.exit1513 ], !dbg !91 + %4316 = phi float [ %2564, %._crit_edge ], [ %4189, %__nv_exp2f.exit1513 ], !dbg !91 + %4317 = phi float [ %2565, %._crit_edge ], [ %4190, %__nv_exp2f.exit1513 ], !dbg !91 + %4318 = phi float [ %2566, %._crit_edge ], [ %4191, %__nv_exp2f.exit1513 ], !dbg !91 + %4319 = phi float [ %2567, %._crit_edge ], [ %4192, %__nv_exp2f.exit1513 ], !dbg !91 + %4320 = phi float [ %2568, %._crit_edge ], [ %4193, %__nv_exp2f.exit1513 ], !dbg !91 + %4321 = phi float [ %2569, %._crit_edge ], [ %4194, %__nv_exp2f.exit1513 ], !dbg !91 + %4322 = phi float [ %2570, %._crit_edge ], [ %4195, %__nv_exp2f.exit1513 ], !dbg !91 + %4323 = phi float [ %2571, %._crit_edge ], [ %4196, %__nv_exp2f.exit1513 ], !dbg !91 + %4324 = phi float [ %2572, %._crit_edge ], [ %4197, %__nv_exp2f.exit1513 ], !dbg !91 + %4325 = phi float [ %2573, %._crit_edge ], [ %4198, %__nv_exp2f.exit1513 ], !dbg !91 + %4326 = phi float [ %2574, %._crit_edge ], [ %4199, %__nv_exp2f.exit1513 ], !dbg !91 + %4327 = phi float [ %2575, %._crit_edge ], [ %4200, %__nv_exp2f.exit1513 ], !dbg !91 + %4328 = phi float [ %2576, %._crit_edge ], [ %4201, %__nv_exp2f.exit1513 ], !dbg !91 + %4329 = phi float [ %2577, %._crit_edge ], [ %4202, %__nv_exp2f.exit1513 ], !dbg !91 + %4330 = phi float [ %2578, %._crit_edge ], [ %4203, %__nv_exp2f.exit1513 ], !dbg !91 + %4331 = phi float [ %2579, %._crit_edge ], [ %4204, %__nv_exp2f.exit1513 ], !dbg !91 + %4332 = phi float [ %2580, %._crit_edge ], [ %4205, %__nv_exp2f.exit1513 ], !dbg !91 + %4333 = phi float [ %2581, %._crit_edge ], [ %4206, %__nv_exp2f.exit1513 ], !dbg !91 + %4334 = phi float [ %2582, %._crit_edge ], [ %4207, %__nv_exp2f.exit1513 ], !dbg !91 + %4335 = phi float [ %2583, %._crit_edge ], [ %4208, %__nv_exp2f.exit1513 ], !dbg !91 + %4336 = phi float [ %2584, %._crit_edge ], [ %4209, %__nv_exp2f.exit1513 ], !dbg !91 + %4337 = phi float [ %2585, %._crit_edge ], [ %4210, %__nv_exp2f.exit1513 ], !dbg !91 + %4338 = phi float [ %2586, %._crit_edge ], [ %4211, %__nv_exp2f.exit1513 ], !dbg !91 + %4339 = phi float [ %2587, %._crit_edge ], [ %4212, %__nv_exp2f.exit1513 ], !dbg !91 + %4340 = phi float [ %2588, %._crit_edge ], [ %4213, %__nv_exp2f.exit1513 ], !dbg !91 + %4341 = phi float [ %2589, %._crit_edge ], [ %4214, %__nv_exp2f.exit1513 ], !dbg !91 + %4342 = phi float [ %2590, %._crit_edge ], [ %4215, %__nv_exp2f.exit1513 ], !dbg !91 + %4343 = phi float [ %2591, %._crit_edge ], [ %4216, %__nv_exp2f.exit1513 ], !dbg !91 + %4344 = phi float [ %2592, %._crit_edge ], [ %4217, %__nv_exp2f.exit1513 ], !dbg !91 + %4345 = phi float [ %2593, %._crit_edge ], [ %4218, %__nv_exp2f.exit1513 ], !dbg !91 + %4346 = phi float [ %2594, %._crit_edge ], [ %4219, %__nv_exp2f.exit1513 ], !dbg !91 + %4347 = phi float [ %2595, %._crit_edge ], [ %4220, %__nv_exp2f.exit1513 ], !dbg !91 + %4348 = phi float [ %2596, %._crit_edge ], [ %4221, %__nv_exp2f.exit1513 ], !dbg !91 + %4349 = phi float [ %2597, %._crit_edge ], [ %4222, %__nv_exp2f.exit1513 ], !dbg !91 + %4350 = phi float [ %2598, %._crit_edge ], [ %4223, %__nv_exp2f.exit1513 ], !dbg !91 + %4351 = phi float [ %2599, %._crit_edge ], [ %4224, %__nv_exp2f.exit1513 ], !dbg !91 + %4352 = phi float [ %2600, %._crit_edge ], [ %4225, %__nv_exp2f.exit1513 ], !dbg !91 + %4353 = phi float [ %2601, %._crit_edge ], [ %4226, %__nv_exp2f.exit1513 ], !dbg !91 + %4354 = phi float [ %2602, %._crit_edge ], [ %4227, %__nv_exp2f.exit1513 ], !dbg !91 + %4355 = phi float [ %2603, %._crit_edge ], [ %4228, %__nv_exp2f.exit1513 ], !dbg !91 + %4356 = phi float [ %2604, %._crit_edge ], [ %4229, %__nv_exp2f.exit1513 ], !dbg !91 + %4357 = phi float [ %2605, %._crit_edge ], [ %4230, %__nv_exp2f.exit1513 ], !dbg !91 + %4358 = phi float [ %2606, %._crit_edge ], [ %4231, %__nv_exp2f.exit1513 ], !dbg !91 + %4359 = phi float [ %2607, %._crit_edge ], [ %4232, %__nv_exp2f.exit1513 ], !dbg !91 + %4360 = phi float [ %2608, %._crit_edge ], [ %4233, %__nv_exp2f.exit1513 ], !dbg !91 + %4361 = phi float [ %2609, %._crit_edge ], [ %4234, %__nv_exp2f.exit1513 ], !dbg !91 + %4362 = phi float [ %2610, %._crit_edge ], [ %4235, %__nv_exp2f.exit1513 ], !dbg !91 + %4363 = phi float [ %2611, %._crit_edge ], [ %4236, %__nv_exp2f.exit1513 ], !dbg !91 + %4364 = phi float [ %2612, %._crit_edge ], [ %4237, %__nv_exp2f.exit1513 ], !dbg !91 + %4365 = phi float [ %2613, %._crit_edge ], [ %4238, %__nv_exp2f.exit1513 ], !dbg !91 + %4366 = phi float [ %2614, %._crit_edge ], [ %4239, %__nv_exp2f.exit1513 ], !dbg !91 + %4367 = phi float [ %2615, %._crit_edge ], [ %4240, %__nv_exp2f.exit1513 ], !dbg !91 + %4368 = phi float [ %2616, %._crit_edge ], [ %4241, %__nv_exp2f.exit1513 ], !dbg !91 + %4369 = phi float [ %2617, %._crit_edge ], [ %4242, %__nv_exp2f.exit1513 ], !dbg !91 + %4370 = phi float [ %2618, %._crit_edge ], [ %4243, %__nv_exp2f.exit1513 ], !dbg !91 + %4371 = phi float [ %2619, %._crit_edge ], [ %4244, %__nv_exp2f.exit1513 ], !dbg !91 + %4372 = phi float [ %2620, %._crit_edge ], [ %4245, %__nv_exp2f.exit1513 ], !dbg !91 + %4373 = phi float [ %2621, %._crit_edge ], [ %4246, %__nv_exp2f.exit1513 ], !dbg !91 + %4374 = phi float [ %2622, %._crit_edge ], [ %4247, %__nv_exp2f.exit1513 ], !dbg !91 + %4375 = phi float [ %2623, %._crit_edge ], [ %4248, %__nv_exp2f.exit1513 ], !dbg !91 + %4376 = phi float [ %2624, %._crit_edge ], [ %4249, %__nv_exp2f.exit1513 ], !dbg !91 + %4377 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63"(float %4313, float %4314, float %4315, float %4316, float %4317, float %4318, float %4319, float %4320, float %4321, float %4322, float %4323, float %4324, float %4325, float %4326, float %4327, float %4328, float %4329, float %4330, float %4331, float %4332, float %4333, float %4334, float %4335, float %4336, float %4337, float %4338, float %4339, float %4340, float %4341, float %4342, float %4343, float %4344, float %4345, float %4346, float %4347, float %4348, float %4349, float %4350, float %4351, float %4352, float %4353, float %4354, float %4355, float %4356, float %4357, float %4358, float %4359, float %4360, float %4361, float %4362, float %4363, float %4364, float %4365, float %4366, float %4367, float %4368, float %4369, float %4370, float %4371, float %4372, float %4373, float %4374, float %4375, float %4376) #3, !dbg !166 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !166 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !166 + %4378 = getelementptr bfloat, ptr addrspace(1) %83, i64 %105, !dbg !199 + %4379 = getelementptr bfloat, ptr addrspace(1) %83, i64 %107, !dbg !199 + %4380 = getelementptr bfloat, ptr addrspace(1) %83, i64 %109, !dbg !199 + %4381 = getelementptr bfloat, ptr addrspace(1) %83, i64 %111, !dbg !199 + %4382 = getelementptr bfloat, ptr addrspace(1) %83, i64 %113, !dbg !199 + %4383 = getelementptr bfloat, ptr addrspace(1) %83, i64 %115, !dbg !199 + %4384 = getelementptr bfloat, ptr addrspace(1) %83, i64 %117, !dbg !199 + %4385 = getelementptr bfloat, ptr addrspace(1) %83, i64 %119, !dbg !199 + %4386 = getelementptr bfloat, ptr addrspace(1) %4378, i64 %123, !dbg !200 + %4387 = getelementptr bfloat, ptr addrspace(1) %4379, i64 %123, !dbg !200 + %4388 = getelementptr bfloat, ptr addrspace(1) %4380, i64 %123, !dbg !200 + %4389 = getelementptr bfloat, ptr addrspace(1) %4381, i64 %123, !dbg !200 + %4390 = getelementptr bfloat, ptr addrspace(1) %4382, i64 %123, !dbg !200 + %4391 = getelementptr bfloat, ptr addrspace(1) %4383, i64 %123, !dbg !200 + %4392 = getelementptr bfloat, ptr addrspace(1) %4384, i64 %123, !dbg !200 + %4393 = getelementptr bfloat, ptr addrspace(1) %4385, i64 %123, !dbg !200 + %4394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 0, !dbg !201 + %4395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 1, !dbg !201 + %4396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 2, !dbg !201 + %4397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 3, !dbg !201 + %4398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 4, !dbg !201 + %4399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 5, !dbg !201 + %4400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 6, !dbg !201 + %4401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 7, !dbg !201 + %4402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 8, !dbg !201 + %4403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 9, !dbg !201 + %4404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 10, !dbg !201 + %4405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 11, !dbg !201 + %4406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 12, !dbg !201 + %4407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 13, !dbg !201 + %4408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 14, !dbg !201 + %4409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 15, !dbg !201 + %4410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 16, !dbg !201 + %4411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 17, !dbg !201 + %4412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 18, !dbg !201 + %4413 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 19, !dbg !201 + %4414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 20, !dbg !201 + %4415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 21, !dbg !201 + %4416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 22, !dbg !201 + %4417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 23, !dbg !201 + %4418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 24, !dbg !201 + %4419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 25, !dbg !201 + %4420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 26, !dbg !201 + %4421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 27, !dbg !201 + %4422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 28, !dbg !201 + %4423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 29, !dbg !201 + %4424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 30, !dbg !201 + %4425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 31, !dbg !201 + %4426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 32, !dbg !201 + %4427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 33, !dbg !201 + %4428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 34, !dbg !201 + %4429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 35, !dbg !201 + %4430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 36, !dbg !201 + %4431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 37, !dbg !201 + %4432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 38, !dbg !201 + %4433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 39, !dbg !201 + %4434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 40, !dbg !201 + %4435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 41, !dbg !201 + %4436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 42, !dbg !201 + %4437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 43, !dbg !201 + %4438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 44, !dbg !201 + %4439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 45, !dbg !201 + %4440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 46, !dbg !201 + %4441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 47, !dbg !201 + %4442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 48, !dbg !201 + %4443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 49, !dbg !201 + %4444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 50, !dbg !201 + %4445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 51, !dbg !201 + %4446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 52, !dbg !201 + %4447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 53, !dbg !201 + %4448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 54, !dbg !201 + %4449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 55, !dbg !201 + %4450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 56, !dbg !201 + %4451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 57, !dbg !201 + %4452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 58, !dbg !201 + %4453 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 59, !dbg !201 + %4454 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 60, !dbg !201 + %4455 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 61, !dbg !201 + %4456 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 62, !dbg !201 + %4457 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %4377, 63, !dbg !201 + %4458 = insertelement <2 x float> poison, float %4394, i64 0, !dbg !201 + %4459 = insertelement <2 x float> %4458, float %4395, i64 1, !dbg !201 + %4460 = fmul <2 x float> %4459, splat (float 0x3FB6A09E60000000), !dbg !201 + %4461 = fptrunc <2 x float> %4460 to <2 x bfloat>, !dbg !202 + %4462 = insertelement <2 x float> poison, float %4396, i64 0, !dbg !201 + %4463 = insertelement <2 x float> %4462, float %4397, i64 1, !dbg !201 + %4464 = fmul <2 x float> %4463, splat (float 0x3FB6A09E60000000), !dbg !201 + %4465 = fptrunc <2 x float> %4464 to <2 x bfloat>, !dbg !202 + %4466 = insertelement <2 x float> poison, float %4398, i64 0, !dbg !201 + %4467 = insertelement <2 x float> %4466, float %4399, i64 1, !dbg !201 + %4468 = fmul <2 x float> %4467, splat (float 0x3FB6A09E60000000), !dbg !201 + %4469 = fptrunc <2 x float> %4468 to <2 x bfloat>, !dbg !202 + %4470 = insertelement <2 x float> poison, float %4400, i64 0, !dbg !201 + %4471 = insertelement <2 x float> %4470, float %4401, i64 1, !dbg !201 + %4472 = fmul <2 x float> %4471, splat (float 0x3FB6A09E60000000), !dbg !201 + %4473 = fptrunc <2 x float> %4472 to <2 x bfloat>, !dbg !202 + %4474 = insertelement <2 x float> poison, float %4402, i64 0, !dbg !201 + %4475 = insertelement <2 x float> %4474, float %4403, i64 1, !dbg !201 + %4476 = fmul <2 x float> %4475, splat (float 0x3FB6A09E60000000), !dbg !201 + %4477 = fptrunc <2 x float> %4476 to <2 x bfloat>, !dbg !202 + %4478 = insertelement <2 x float> poison, float %4404, i64 0, !dbg !201 + %4479 = insertelement <2 x float> %4478, float %4405, i64 1, !dbg !201 + %4480 = fmul <2 x float> %4479, splat (float 0x3FB6A09E60000000), !dbg !201 + %4481 = fptrunc <2 x float> %4480 to <2 x bfloat>, !dbg !202 + %4482 = insertelement <2 x float> poison, float %4406, i64 0, !dbg !201 + %4483 = insertelement <2 x float> %4482, float %4407, i64 1, !dbg !201 + %4484 = fmul <2 x float> %4483, splat (float 0x3FB6A09E60000000), !dbg !201 + %4485 = fptrunc <2 x float> %4484 to <2 x bfloat>, !dbg !202 + %4486 = insertelement <2 x float> poison, float %4408, i64 0, !dbg !201 + %4487 = insertelement <2 x float> %4486, float %4409, i64 1, !dbg !201 + %4488 = fmul <2 x float> %4487, splat (float 0x3FB6A09E60000000), !dbg !201 + %4489 = fptrunc <2 x float> %4488 to <2 x bfloat>, !dbg !202 + %4490 = insertelement <2 x float> poison, float %4410, i64 0, !dbg !201 + %4491 = insertelement <2 x float> %4490, float %4411, i64 1, !dbg !201 + %4492 = fmul <2 x float> %4491, splat (float 0x3FB6A09E60000000), !dbg !201 + %4493 = fptrunc <2 x float> %4492 to <2 x bfloat>, !dbg !202 + %4494 = insertelement <2 x float> poison, float %4412, i64 0, !dbg !201 + %4495 = insertelement <2 x float> %4494, float %4413, i64 1, !dbg !201 + %4496 = fmul <2 x float> %4495, splat (float 0x3FB6A09E60000000), !dbg !201 + %4497 = fptrunc <2 x float> %4496 to <2 x bfloat>, !dbg !202 + %4498 = insertelement <2 x float> poison, float %4414, i64 0, !dbg !201 + %4499 = insertelement <2 x float> %4498, float %4415, i64 1, !dbg !201 + %4500 = fmul <2 x float> %4499, splat (float 0x3FB6A09E60000000), !dbg !201 + %4501 = fptrunc <2 x float> %4500 to <2 x bfloat>, !dbg !202 + %4502 = insertelement <2 x float> poison, float %4416, i64 0, !dbg !201 + %4503 = insertelement <2 x float> %4502, float %4417, i64 1, !dbg !201 + %4504 = fmul <2 x float> %4503, splat (float 0x3FB6A09E60000000), !dbg !201 + %4505 = fptrunc <2 x float> %4504 to <2 x bfloat>, !dbg !202 + %4506 = insertelement <2 x float> poison, float %4418, i64 0, !dbg !201 + %4507 = insertelement <2 x float> %4506, float %4419, i64 1, !dbg !201 + %4508 = fmul <2 x float> %4507, splat (float 0x3FB6A09E60000000), !dbg !201 + %4509 = fptrunc <2 x float> %4508 to <2 x bfloat>, !dbg !202 + %4510 = insertelement <2 x float> poison, float %4420, i64 0, !dbg !201 + %4511 = insertelement <2 x float> %4510, float %4421, i64 1, !dbg !201 + %4512 = fmul <2 x float> %4511, splat (float 0x3FB6A09E60000000), !dbg !201 + %4513 = fptrunc <2 x float> %4512 to <2 x bfloat>, !dbg !202 + %4514 = insertelement <2 x float> poison, float %4422, i64 0, !dbg !201 + %4515 = insertelement <2 x float> %4514, float %4423, i64 1, !dbg !201 + %4516 = fmul <2 x float> %4515, splat (float 0x3FB6A09E60000000), !dbg !201 + %4517 = fptrunc <2 x float> %4516 to <2 x bfloat>, !dbg !202 + %4518 = insertelement <2 x float> poison, float %4424, i64 0, !dbg !201 + %4519 = insertelement <2 x float> %4518, float %4425, i64 1, !dbg !201 + %4520 = fmul <2 x float> %4519, splat (float 0x3FB6A09E60000000), !dbg !201 + %4521 = fptrunc <2 x float> %4520 to <2 x bfloat>, !dbg !202 + %4522 = insertelement <2 x float> poison, float %4426, i64 0, !dbg !201 + %4523 = insertelement <2 x float> %4522, float %4427, i64 1, !dbg !201 + %4524 = fmul <2 x float> %4523, splat (float 0x3FB6A09E60000000), !dbg !201 + %4525 = fptrunc <2 x float> %4524 to <2 x bfloat>, !dbg !202 + %4526 = insertelement <2 x float> poison, float %4428, i64 0, !dbg !201 + %4527 = insertelement <2 x float> %4526, float %4429, i64 1, !dbg !201 + %4528 = fmul <2 x float> %4527, splat (float 0x3FB6A09E60000000), !dbg !201 + %4529 = fptrunc <2 x float> %4528 to <2 x bfloat>, !dbg !202 + %4530 = insertelement <2 x float> poison, float %4430, i64 0, !dbg !201 + %4531 = insertelement <2 x float> %4530, float %4431, i64 1, !dbg !201 + %4532 = fmul <2 x float> %4531, splat (float 0x3FB6A09E60000000), !dbg !201 + %4533 = fptrunc <2 x float> %4532 to <2 x bfloat>, !dbg !202 + %4534 = insertelement <2 x float> poison, float %4432, i64 0, !dbg !201 + %4535 = insertelement <2 x float> %4534, float %4433, i64 1, !dbg !201 + %4536 = fmul <2 x float> %4535, splat (float 0x3FB6A09E60000000), !dbg !201 + %4537 = fptrunc <2 x float> %4536 to <2 x bfloat>, !dbg !202 + %4538 = insertelement <2 x float> poison, float %4434, i64 0, !dbg !201 + %4539 = insertelement <2 x float> %4538, float %4435, i64 1, !dbg !201 + %4540 = fmul <2 x float> %4539, splat (float 0x3FB6A09E60000000), !dbg !201 + %4541 = fptrunc <2 x float> %4540 to <2 x bfloat>, !dbg !202 + %4542 = insertelement <2 x float> poison, float %4436, i64 0, !dbg !201 + %4543 = insertelement <2 x float> %4542, float %4437, i64 1, !dbg !201 + %4544 = fmul <2 x float> %4543, splat (float 0x3FB6A09E60000000), !dbg !201 + %4545 = fptrunc <2 x float> %4544 to <2 x bfloat>, !dbg !202 + %4546 = insertelement <2 x float> poison, float %4438, i64 0, !dbg !201 + %4547 = insertelement <2 x float> %4546, float %4439, i64 1, !dbg !201 + %4548 = fmul <2 x float> %4547, splat (float 0x3FB6A09E60000000), !dbg !201 + %4549 = fptrunc <2 x float> %4548 to <2 x bfloat>, !dbg !202 + %4550 = insertelement <2 x float> poison, float %4440, i64 0, !dbg !201 + %4551 = insertelement <2 x float> %4550, float %4441, i64 1, !dbg !201 + %4552 = fmul <2 x float> %4551, splat (float 0x3FB6A09E60000000), !dbg !201 + %4553 = fptrunc <2 x float> %4552 to <2 x bfloat>, !dbg !202 + %4554 = insertelement <2 x float> poison, float %4442, i64 0, !dbg !201 + %4555 = insertelement <2 x float> %4554, float %4443, i64 1, !dbg !201 + %4556 = fmul <2 x float> %4555, splat (float 0x3FB6A09E60000000), !dbg !201 + %4557 = fptrunc <2 x float> %4556 to <2 x bfloat>, !dbg !202 + %4558 = insertelement <2 x float> poison, float %4444, i64 0, !dbg !201 + %4559 = insertelement <2 x float> %4558, float %4445, i64 1, !dbg !201 + %4560 = fmul <2 x float> %4559, splat (float 0x3FB6A09E60000000), !dbg !201 + %4561 = fptrunc <2 x float> %4560 to <2 x bfloat>, !dbg !202 + %4562 = insertelement <2 x float> poison, float %4446, i64 0, !dbg !201 + %4563 = insertelement <2 x float> %4562, float %4447, i64 1, !dbg !201 + %4564 = fmul <2 x float> %4563, splat (float 0x3FB6A09E60000000), !dbg !201 + %4565 = fptrunc <2 x float> %4564 to <2 x bfloat>, !dbg !202 + %4566 = insertelement <2 x float> poison, float %4448, i64 0, !dbg !201 + %4567 = insertelement <2 x float> %4566, float %4449, i64 1, !dbg !201 + %4568 = fmul <2 x float> %4567, splat (float 0x3FB6A09E60000000), !dbg !201 + %4569 = fptrunc <2 x float> %4568 to <2 x bfloat>, !dbg !202 + %4570 = insertelement <2 x float> poison, float %4450, i64 0, !dbg !201 + %4571 = insertelement <2 x float> %4570, float %4451, i64 1, !dbg !201 + %4572 = fmul <2 x float> %4571, splat (float 0x3FB6A09E60000000), !dbg !201 + %4573 = fptrunc <2 x float> %4572 to <2 x bfloat>, !dbg !202 + %4574 = insertelement <2 x float> poison, float %4452, i64 0, !dbg !201 + %4575 = insertelement <2 x float> %4574, float %4453, i64 1, !dbg !201 + %4576 = fmul <2 x float> %4575, splat (float 0x3FB6A09E60000000), !dbg !201 + %4577 = fptrunc <2 x float> %4576 to <2 x bfloat>, !dbg !202 + %4578 = insertelement <2 x float> poison, float %4454, i64 0, !dbg !201 + %4579 = insertelement <2 x float> %4578, float %4455, i64 1, !dbg !201 + %4580 = fmul <2 x float> %4579, splat (float 0x3FB6A09E60000000), !dbg !201 + %4581 = fptrunc <2 x float> %4580 to <2 x bfloat>, !dbg !202 + %4582 = insertelement <2 x float> poison, float %4456, i64 0, !dbg !201 + %4583 = insertelement <2 x float> %4582, float %4457, i64 1, !dbg !201 + %4584 = fmul <2 x float> %4583, splat (float 0x3FB6A09E60000000), !dbg !201 + %4585 = fptrunc <2 x float> %4584 to <2 x bfloat>, !dbg !202 + %4586 = shl nuw nsw i32 %374, 13, !dbg !202 + %4587 = shl nuw nsw i32 %44, 5, !dbg !202 + %4588 = and i32 %4587, 7264, !dbg !202 + %4589 = and i32 %44, 24, !dbg !202 + %4590 = shl nuw nsw i32 %4589, 4, !dbg !202 + %4591 = shl nuw nsw i32 %44, 2, !dbg !202 + %4592 = and i32 %4591, 16, !dbg !202 + %4593 = or disjoint i32 %4586, %4592, !dbg !202 + %4594 = or disjoint i32 %4588, %4590, !dbg !202 + %4595 = or disjoint i32 %4593, %4594, !dbg !202 + %4596 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4595, !dbg !202 + %4597 = bitcast <2 x bfloat> %4461 to i32, !dbg !202 + %4598 = bitcast <2 x bfloat> %4469 to i32, !dbg !202 + %4599 = bitcast <2 x bfloat> %4477 to i32, !dbg !202 + %4600 = bitcast <2 x bfloat> %4485 to i32, !dbg !202 + %4601 = insertelement <4 x i32> poison, i32 %4597, i64 0, !dbg !202 + %4602 = insertelement <4 x i32> %4601, i32 %4598, i64 1, !dbg !202 + %4603 = insertelement <4 x i32> %4602, i32 %4599, i64 2, !dbg !202 + %4604 = insertelement <4 x i32> %4603, i32 %4600, i64 3, !dbg !202 + store <4 x i32> %4604, ptr addrspace(3) %4596, align 16, !dbg !202 + %4605 = getelementptr inbounds nuw i8, ptr addrspace(3) %4596, i32 512, !dbg !202 + %4606 = bitcast <2 x bfloat> %4465 to i32, !dbg !202 + %4607 = bitcast <2 x bfloat> %4473 to i32, !dbg !202 + %4608 = bitcast <2 x bfloat> %4481 to i32, !dbg !202 + %4609 = bitcast <2 x bfloat> %4489 to i32, !dbg !202 + %4610 = insertelement <4 x i32> poison, i32 %4606, i64 0, !dbg !202 + %4611 = insertelement <4 x i32> %4610, i32 %4607, i64 1, !dbg !202 + %4612 = insertelement <4 x i32> %4611, i32 %4608, i64 2, !dbg !202 + %4613 = insertelement <4 x i32> %4612, i32 %4609, i64 3, !dbg !202 + store <4 x i32> %4613, ptr addrspace(3) %4605, align 16, !dbg !202 + %4614 = xor i32 %4595, 32, !dbg !202 + %4615 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4614, !dbg !202 + %4616 = bitcast <2 x bfloat> %4493 to i32, !dbg !202 + %4617 = bitcast <2 x bfloat> %4501 to i32, !dbg !202 + %4618 = bitcast <2 x bfloat> %4509 to i32, !dbg !202 + %4619 = bitcast <2 x bfloat> %4517 to i32, !dbg !202 + %4620 = insertelement <4 x i32> poison, i32 %4616, i64 0, !dbg !202 + %4621 = insertelement <4 x i32> %4620, i32 %4617, i64 1, !dbg !202 + %4622 = insertelement <4 x i32> %4621, i32 %4618, i64 2, !dbg !202 + %4623 = insertelement <4 x i32> %4622, i32 %4619, i64 3, !dbg !202 + store <4 x i32> %4623, ptr addrspace(3) %4615, align 16, !dbg !202 + %4624 = getelementptr inbounds nuw i8, ptr addrspace(3) %4615, i32 512, !dbg !202 + %4625 = bitcast <2 x bfloat> %4497 to i32, !dbg !202 + %4626 = bitcast <2 x bfloat> %4505 to i32, !dbg !202 + %4627 = bitcast <2 x bfloat> %4513 to i32, !dbg !202 + %4628 = bitcast <2 x bfloat> %4521 to i32, !dbg !202 + %4629 = insertelement <4 x i32> poison, i32 %4625, i64 0, !dbg !202 + %4630 = insertelement <4 x i32> %4629, i32 %4626, i64 1, !dbg !202 + %4631 = insertelement <4 x i32> %4630, i32 %4627, i64 2, !dbg !202 + %4632 = insertelement <4 x i32> %4631, i32 %4628, i64 3, !dbg !202 + store <4 x i32> %4632, ptr addrspace(3) %4624, align 16, !dbg !202 + %4633 = xor i32 %4595, 64, !dbg !202 + %4634 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4633, !dbg !202 + %4635 = bitcast <2 x bfloat> %4525 to i32, !dbg !202 + %4636 = bitcast <2 x bfloat> %4533 to i32, !dbg !202 + %4637 = bitcast <2 x bfloat> %4541 to i32, !dbg !202 + %4638 = bitcast <2 x bfloat> %4549 to i32, !dbg !202 + %4639 = insertelement <4 x i32> poison, i32 %4635, i64 0, !dbg !202 + %4640 = insertelement <4 x i32> %4639, i32 %4636, i64 1, !dbg !202 + %4641 = insertelement <4 x i32> %4640, i32 %4637, i64 2, !dbg !202 + %4642 = insertelement <4 x i32> %4641, i32 %4638, i64 3, !dbg !202 + store <4 x i32> %4642, ptr addrspace(3) %4634, align 16, !dbg !202 + %4643 = getelementptr inbounds nuw i8, ptr addrspace(3) %4634, i32 512, !dbg !202 + %4644 = bitcast <2 x bfloat> %4529 to i32, !dbg !202 + %4645 = bitcast <2 x bfloat> %4537 to i32, !dbg !202 + %4646 = bitcast <2 x bfloat> %4545 to i32, !dbg !202 + %4647 = bitcast <2 x bfloat> %4553 to i32, !dbg !202 + %4648 = insertelement <4 x i32> poison, i32 %4644, i64 0, !dbg !202 + %4649 = insertelement <4 x i32> %4648, i32 %4645, i64 1, !dbg !202 + %4650 = insertelement <4 x i32> %4649, i32 %4646, i64 2, !dbg !202 + %4651 = insertelement <4 x i32> %4650, i32 %4647, i64 3, !dbg !202 + store <4 x i32> %4651, ptr addrspace(3) %4643, align 16, !dbg !202 + %4652 = xor i32 %4595, 96, !dbg !202 + %4653 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4652, !dbg !202 + %4654 = bitcast <2 x bfloat> %4557 to i32, !dbg !202 + %4655 = bitcast <2 x bfloat> %4565 to i32, !dbg !202 + %4656 = bitcast <2 x bfloat> %4573 to i32, !dbg !202 + %4657 = bitcast <2 x bfloat> %4581 to i32, !dbg !202 + %4658 = insertelement <4 x i32> poison, i32 %4654, i64 0, !dbg !202 + %4659 = insertelement <4 x i32> %4658, i32 %4655, i64 1, !dbg !202 + %4660 = insertelement <4 x i32> %4659, i32 %4656, i64 2, !dbg !202 + %4661 = insertelement <4 x i32> %4660, i32 %4657, i64 3, !dbg !202 + store <4 x i32> %4661, ptr addrspace(3) %4653, align 16, !dbg !202 + %4662 = getelementptr inbounds nuw i8, ptr addrspace(3) %4653, i32 512, !dbg !202 + %4663 = bitcast <2 x bfloat> %4561 to i32, !dbg !202 + %4664 = bitcast <2 x bfloat> %4569 to i32, !dbg !202 + %4665 = bitcast <2 x bfloat> %4577 to i32, !dbg !202 + %4666 = bitcast <2 x bfloat> %4585 to i32, !dbg !202 + %4667 = insertelement <4 x i32> poison, i32 %4663, i64 0, !dbg !202 + %4668 = insertelement <4 x i32> %4667, i32 %4664, i64 1, !dbg !202 + %4669 = insertelement <4 x i32> %4668, i32 %4665, i64 2, !dbg !202 + %4670 = insertelement <4 x i32> %4669, i32 %4666, i64 3, !dbg !202 + store <4 x i32> %4670, ptr addrspace(3) %4662, align 16, !dbg !202 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !202 + %4671 = shl nuw nsw i32 %4589, 10, !dbg !202 + %4672 = shl nuw nsw i32 %374, 5, !dbg !202 + %4673 = and i32 %4591, 1008, !dbg !202 + %4674 = or disjoint i32 %4671, %4672, !dbg !202 + %4675 = xor i32 %4674, %4673, !dbg !202 + %4676 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %4675, !dbg !202 + %4677 = ptrtoint ptr addrspace(3) %4676 to i32, !dbg !202 + %4678 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4677) #3, !dbg !202 + %4679 = extractvalue { i32, i32, i32, i32 } %4678, 0, !dbg !202 + %4680 = extractvalue { i32, i32, i32, i32 } %4678, 1, !dbg !202 + %4681 = extractvalue { i32, i32, i32, i32 } %4678, 2, !dbg !202 + %4682 = extractvalue { i32, i32, i32, i32 } %4678, 3, !dbg !202 + %4683 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 1024, !dbg !202 + %4684 = ptrtoint ptr addrspace(3) %4683 to i32, !dbg !202 + %4685 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4684) #3, !dbg !202 + %4686 = extractvalue { i32, i32, i32, i32 } %4685, 0, !dbg !202 + %4687 = extractvalue { i32, i32, i32, i32 } %4685, 1, !dbg !202 + %4688 = extractvalue { i32, i32, i32, i32 } %4685, 2, !dbg !202 + %4689 = extractvalue { i32, i32, i32, i32 } %4685, 3, !dbg !202 + %4690 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 2048, !dbg !202 + %4691 = ptrtoint ptr addrspace(3) %4690 to i32, !dbg !202 + %4692 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4691) #3, !dbg !202 + %4693 = extractvalue { i32, i32, i32, i32 } %4692, 0, !dbg !202 + %4694 = extractvalue { i32, i32, i32, i32 } %4692, 1, !dbg !202 + %4695 = extractvalue { i32, i32, i32, i32 } %4692, 2, !dbg !202 + %4696 = extractvalue { i32, i32, i32, i32 } %4692, 3, !dbg !202 + %4697 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 3072, !dbg !202 + %4698 = ptrtoint ptr addrspace(3) %4697 to i32, !dbg !202 + %4699 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4698) #3, !dbg !202 + %4700 = extractvalue { i32, i32, i32, i32 } %4699, 0, !dbg !202 + %4701 = extractvalue { i32, i32, i32, i32 } %4699, 1, !dbg !202 + %4702 = extractvalue { i32, i32, i32, i32 } %4699, 2, !dbg !202 + %4703 = extractvalue { i32, i32, i32, i32 } %4699, 3, !dbg !202 + %4704 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 4096, !dbg !202 + %4705 = ptrtoint ptr addrspace(3) %4704 to i32, !dbg !202 + %4706 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4705) #3, !dbg !202 + %4707 = extractvalue { i32, i32, i32, i32 } %4706, 0, !dbg !202 + %4708 = extractvalue { i32, i32, i32, i32 } %4706, 1, !dbg !202 + %4709 = extractvalue { i32, i32, i32, i32 } %4706, 2, !dbg !202 + %4710 = extractvalue { i32, i32, i32, i32 } %4706, 3, !dbg !202 + %4711 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 5120, !dbg !202 + %4712 = ptrtoint ptr addrspace(3) %4711 to i32, !dbg !202 + %4713 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4712) #3, !dbg !202 + %4714 = extractvalue { i32, i32, i32, i32 } %4713, 0, !dbg !202 + %4715 = extractvalue { i32, i32, i32, i32 } %4713, 1, !dbg !202 + %4716 = extractvalue { i32, i32, i32, i32 } %4713, 2, !dbg !202 + %4717 = extractvalue { i32, i32, i32, i32 } %4713, 3, !dbg !202 + %4718 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 6144, !dbg !202 + %4719 = ptrtoint ptr addrspace(3) %4718 to i32, !dbg !202 + %4720 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4719) #3, !dbg !202 + %4721 = extractvalue { i32, i32, i32, i32 } %4720, 0, !dbg !202 + %4722 = extractvalue { i32, i32, i32, i32 } %4720, 1, !dbg !202 + %4723 = extractvalue { i32, i32, i32, i32 } %4720, 2, !dbg !202 + %4724 = extractvalue { i32, i32, i32, i32 } %4720, 3, !dbg !202 + %4725 = getelementptr inbounds nuw i8, ptr addrspace(3) %4676, i32 7168, !dbg !202 + %4726 = ptrtoint ptr addrspace(3) %4725 to i32, !dbg !202 + %4727 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %4726) #3, !dbg !202 + %4728 = extractvalue { i32, i32, i32, i32 } %4727, 0, !dbg !202 + %4729 = extractvalue { i32, i32, i32, i32 } %4727, 1, !dbg !202 + %4730 = extractvalue { i32, i32, i32, i32 } %4727, 2, !dbg !202 + %4731 = extractvalue { i32, i32, i32, i32 } %4727, 3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4679, i32 %4680, i32 %4681, i32 %4682, ptr addrspace(1) %4386, i1 %132) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4686, i32 %4687, i32 %4688, i32 %4689, ptr addrspace(1) %4387, i1 %133) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4693, i32 %4694, i32 %4695, i32 %4696, ptr addrspace(1) %4388, i1 %134) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4700, i32 %4701, i32 %4702, i32 %4703, ptr addrspace(1) %4389, i1 %135) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4707, i32 %4708, i32 %4709, i32 %4710, ptr addrspace(1) %4390, i1 %136) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4714, i32 %4715, i32 %4716, i32 %4717, ptr addrspace(1) %4391, i1 %137) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4721, i32 %4722, i32 %4723, i32 %4724, ptr addrspace(1) %4392, i1 %138) #3, !dbg !202 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %4728, i32 %4729, i32 %4730, i32 %4731, ptr addrspace(1) %4393, i1 %139) #3, !dbg !202 + br label %11615, !dbg !35 + +4732: ; preds = %21 + %4733 = shl nuw nsw i32 %30, 7, !dbg !203 + %4734 = or disjoint i32 %47, %4733, !dbg !204 + %4735 = or disjoint i32 %48, %4733, !dbg !204 + %4736 = or disjoint i32 %49, %4733, !dbg !204 + %4737 = or disjoint i32 %50, %4733, !dbg !204 + %4738 = or disjoint i32 %51, %4733, !dbg !204 + %4739 = or disjoint i32 %52, %4733, !dbg !204 + %4740 = or disjoint i32 %53, %4733, !dbg !204 + %4741 = or disjoint i32 %54, %4733, !dbg !204 + %4742 = shl i32 %4734, 10, !dbg !205 + %4743 = shl i32 %4735, 10, !dbg !205 + %4744 = shl i32 %4736, 10, !dbg !205 + %4745 = shl i32 %4737, 10, !dbg !205 + %4746 = shl i32 %4738, 10, !dbg !205 + %4747 = shl i32 %4739, 10, !dbg !205 + %4748 = shl i32 %4740, 10, !dbg !205 + %4749 = shl i32 %4741, 10, !dbg !205 + %4750 = sext i32 %4742 to i64, !dbg !207 + %4751 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4750, !dbg !207 + %4752 = sext i32 %4743 to i64, !dbg !207 + %4753 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4752, !dbg !207 + %4754 = sext i32 %4744 to i64, !dbg !207 + %4755 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4754, !dbg !207 + %4756 = sext i32 %4745 to i64, !dbg !207 + %4757 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4756, !dbg !207 + %4758 = sext i32 %4746 to i64, !dbg !207 + %4759 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4758, !dbg !207 + %4760 = sext i32 %4747 to i64, !dbg !207 + %4761 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4760, !dbg !207 + %4762 = sext i32 %4748 to i64, !dbg !207 + %4763 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4762, !dbg !207 + %4764 = sext i32 %4749 to i64, !dbg !207 + %4765 = getelementptr bfloat, ptr addrspace(1) %41, i64 %4764, !dbg !207 + %4766 = shl nuw nsw i32 %44, 3, !dbg !208 + %4767 = and i32 %4766, 120, !dbg !208 + %4768 = zext nneg i32 %4767 to i64, !dbg !209 + %4769 = getelementptr bfloat, ptr addrspace(1) %4751, i64 %4768, !dbg !209 + %4770 = getelementptr bfloat, ptr addrspace(1) %4753, i64 %4768, !dbg !209 + %4771 = getelementptr bfloat, ptr addrspace(1) %4755, i64 %4768, !dbg !209 + %4772 = getelementptr bfloat, ptr addrspace(1) %4757, i64 %4768, !dbg !209 + %4773 = getelementptr bfloat, ptr addrspace(1) %4759, i64 %4768, !dbg !209 + %4774 = getelementptr bfloat, ptr addrspace(1) %4761, i64 %4768, !dbg !209 + %4775 = getelementptr bfloat, ptr addrspace(1) %4763, i64 %4768, !dbg !209 + %4776 = getelementptr bfloat, ptr addrspace(1) %4765, i64 %4768, !dbg !209 + %4777 = icmp slt i32 %4734, %18, !dbg !210 + %4778 = icmp slt i32 %4735, %18, !dbg !210 + %4779 = icmp slt i32 %4736, %18, !dbg !210 + %4780 = icmp slt i32 %4737, %18, !dbg !210 + %4781 = icmp slt i32 %4738, %18, !dbg !210 + %4782 = icmp slt i32 %4739, %18, !dbg !210 + %4783 = icmp slt i32 %4740, %18, !dbg !210 + %4784 = icmp slt i32 %4741, %18, !dbg !210 + %4785 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4769, i1 %4777) #3, !dbg !211 + %4786 = extractvalue { i32, i32, i32, i32 } %4785, 0, !dbg !211 + %4787 = extractvalue { i32, i32, i32, i32 } %4785, 1, !dbg !211 + %4788 = extractvalue { i32, i32, i32, i32 } %4785, 2, !dbg !211 + %4789 = extractvalue { i32, i32, i32, i32 } %4785, 3, !dbg !211 + %4790 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4770, i1 %4778) #3, !dbg !211 + %4791 = extractvalue { i32, i32, i32, i32 } %4790, 0, !dbg !211 + %4792 = extractvalue { i32, i32, i32, i32 } %4790, 1, !dbg !211 + %4793 = extractvalue { i32, i32, i32, i32 } %4790, 2, !dbg !211 + %4794 = extractvalue { i32, i32, i32, i32 } %4790, 3, !dbg !211 + %4795 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4771, i1 %4779) #3, !dbg !211 + %4796 = extractvalue { i32, i32, i32, i32 } %4795, 0, !dbg !211 + %4797 = extractvalue { i32, i32, i32, i32 } %4795, 1, !dbg !211 + %4798 = extractvalue { i32, i32, i32, i32 } %4795, 2, !dbg !211 + %4799 = extractvalue { i32, i32, i32, i32 } %4795, 3, !dbg !211 + %4800 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4772, i1 %4780) #3, !dbg !211 + %4801 = extractvalue { i32, i32, i32, i32 } %4800, 0, !dbg !211 + %4802 = extractvalue { i32, i32, i32, i32 } %4800, 1, !dbg !211 + %4803 = extractvalue { i32, i32, i32, i32 } %4800, 2, !dbg !211 + %4804 = extractvalue { i32, i32, i32, i32 } %4800, 3, !dbg !211 + %4805 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4773, i1 %4781) #3, !dbg !211 + %4806 = extractvalue { i32, i32, i32, i32 } %4805, 0, !dbg !211 + %4807 = extractvalue { i32, i32, i32, i32 } %4805, 1, !dbg !211 + %4808 = extractvalue { i32, i32, i32, i32 } %4805, 2, !dbg !211 + %4809 = extractvalue { i32, i32, i32, i32 } %4805, 3, !dbg !211 + %4810 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4774, i1 %4782) #3, !dbg !211 + %4811 = extractvalue { i32, i32, i32, i32 } %4810, 0, !dbg !211 + %4812 = extractvalue { i32, i32, i32, i32 } %4810, 1, !dbg !211 + %4813 = extractvalue { i32, i32, i32, i32 } %4810, 2, !dbg !211 + %4814 = extractvalue { i32, i32, i32, i32 } %4810, 3, !dbg !211 + %4815 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4775, i1 %4783) #3, !dbg !211 + %4816 = extractvalue { i32, i32, i32, i32 } %4815, 0, !dbg !211 + %4817 = extractvalue { i32, i32, i32, i32 } %4815, 1, !dbg !211 + %4818 = extractvalue { i32, i32, i32, i32 } %4815, 2, !dbg !211 + %4819 = extractvalue { i32, i32, i32, i32 } %4815, 3, !dbg !211 + %4820 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4776, i1 %4784) #3, !dbg !211 + %4821 = extractvalue { i32, i32, i32, i32 } %4820, 0, !dbg !211 + %4822 = extractvalue { i32, i32, i32, i32 } %4820, 1, !dbg !211 + %4823 = extractvalue { i32, i32, i32, i32 } %4820, 2, !dbg !211 + %4824 = extractvalue { i32, i32, i32, i32 } %4820, 3, !dbg !211 + %4825 = shl nuw nsw i32 %44, 4, !dbg !211 + %4826 = and i32 %4825, 112, !dbg !211 + %4827 = shl nuw nsw i32 %46, 3, !dbg !211 + %4828 = and i32 %44, 112, !dbg !211 + %4829 = and i32 %44, 8, !dbg !211 + %4830 = shl nuw nsw i32 %4829, 11, !dbg !211 + %4831 = or disjoint i32 %4826, %4827, !dbg !211 + %4832 = xor i32 %4831, %4828, !dbg !211 + %4833 = or disjoint i32 %4832, %4830, !dbg !211 + %4834 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4833, !dbg !211 + %4835 = insertelement <4 x i32> poison, i32 %4786, i64 0, !dbg !211 + %4836 = insertelement <4 x i32> %4835, i32 %4787, i64 1, !dbg !211 + %4837 = insertelement <4 x i32> %4836, i32 %4788, i64 2, !dbg !211 + %4838 = insertelement <4 x i32> %4837, i32 %4789, i64 3, !dbg !211 + store <4 x i32> %4838, ptr addrspace(3) %4834, align 16, !dbg !211 + %4839 = or disjoint i32 %4833, 2048, !dbg !211 + %4840 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4839, !dbg !211 + %4841 = insertelement <4 x i32> poison, i32 %4791, i64 0, !dbg !211 + %4842 = insertelement <4 x i32> %4841, i32 %4792, i64 1, !dbg !211 + %4843 = insertelement <4 x i32> %4842, i32 %4793, i64 2, !dbg !211 + %4844 = insertelement <4 x i32> %4843, i32 %4794, i64 3, !dbg !211 + store <4 x i32> %4844, ptr addrspace(3) %4840, align 16, !dbg !211 + %4845 = or disjoint i32 %4833, 4096, !dbg !211 + %4846 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4845, !dbg !211 + %4847 = insertelement <4 x i32> poison, i32 %4796, i64 0, !dbg !211 + %4848 = insertelement <4 x i32> %4847, i32 %4797, i64 1, !dbg !211 + %4849 = insertelement <4 x i32> %4848, i32 %4798, i64 2, !dbg !211 + %4850 = insertelement <4 x i32> %4849, i32 %4799, i64 3, !dbg !211 + store <4 x i32> %4850, ptr addrspace(3) %4846, align 16, !dbg !211 + %4851 = or disjoint i32 %4833, 6144, !dbg !211 + %4852 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4851, !dbg !211 + %4853 = insertelement <4 x i32> poison, i32 %4801, i64 0, !dbg !211 + %4854 = insertelement <4 x i32> %4853, i32 %4802, i64 1, !dbg !211 + %4855 = insertelement <4 x i32> %4854, i32 %4803, i64 2, !dbg !211 + %4856 = insertelement <4 x i32> %4855, i32 %4804, i64 3, !dbg !211 + store <4 x i32> %4856, ptr addrspace(3) %4852, align 16, !dbg !211 + %4857 = or disjoint i32 %4833, 8192, !dbg !211 + %4858 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4857, !dbg !211 + %4859 = insertelement <4 x i32> poison, i32 %4806, i64 0, !dbg !211 + %4860 = insertelement <4 x i32> %4859, i32 %4807, i64 1, !dbg !211 + %4861 = insertelement <4 x i32> %4860, i32 %4808, i64 2, !dbg !211 + %4862 = insertelement <4 x i32> %4861, i32 %4809, i64 3, !dbg !211 + store <4 x i32> %4862, ptr addrspace(3) %4858, align 16, !dbg !211 + %4863 = or disjoint i32 %4833, 10240, !dbg !211 + %4864 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4863, !dbg !211 + %4865 = insertelement <4 x i32> poison, i32 %4811, i64 0, !dbg !211 + %4866 = insertelement <4 x i32> %4865, i32 %4812, i64 1, !dbg !211 + %4867 = insertelement <4 x i32> %4866, i32 %4813, i64 2, !dbg !211 + %4868 = insertelement <4 x i32> %4867, i32 %4814, i64 3, !dbg !211 + store <4 x i32> %4868, ptr addrspace(3) %4864, align 16, !dbg !211 + %4869 = or disjoint i32 %4833, 12288, !dbg !211 + %4870 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4869, !dbg !211 + %4871 = insertelement <4 x i32> poison, i32 %4816, i64 0, !dbg !211 + %4872 = insertelement <4 x i32> %4871, i32 %4817, i64 1, !dbg !211 + %4873 = insertelement <4 x i32> %4872, i32 %4818, i64 2, !dbg !211 + %4874 = insertelement <4 x i32> %4873, i32 %4819, i64 3, !dbg !211 + store <4 x i32> %4874, ptr addrspace(3) %4870, align 16, !dbg !211 + %4875 = or disjoint i32 %4833, 14336, !dbg !211 + %4876 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 %4875, !dbg !211 + %4877 = insertelement <4 x i32> poison, i32 %4821, i64 0, !dbg !211 + %4878 = insertelement <4 x i32> %4877, i32 %4822, i64 1, !dbg !211 + %4879 = insertelement <4 x i32> %4878, i32 %4823, i64 2, !dbg !211 + %4880 = insertelement <4 x i32> %4879, i32 %4824, i64 3, !dbg !211 + store <4 x i32> %4880, ptr addrspace(3) %4876, align 16, !dbg !211 + %4881 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4750, !dbg !212 + %4882 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4752, !dbg !212 + %4883 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4754, !dbg !212 + %4884 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4756, !dbg !212 + %4885 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4758, !dbg !212 + %4886 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4760, !dbg !212 + %4887 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4762, !dbg !212 + %4888 = getelementptr bfloat, ptr addrspace(1) %42, i64 %4764, !dbg !212 + %4889 = getelementptr bfloat, ptr addrspace(1) %4881, i64 %4768, !dbg !214 + %4890 = getelementptr bfloat, ptr addrspace(1) %4882, i64 %4768, !dbg !214 + %4891 = getelementptr bfloat, ptr addrspace(1) %4883, i64 %4768, !dbg !214 + %4892 = getelementptr bfloat, ptr addrspace(1) %4884, i64 %4768, !dbg !214 + %4893 = getelementptr bfloat, ptr addrspace(1) %4885, i64 %4768, !dbg !214 + %4894 = getelementptr bfloat, ptr addrspace(1) %4886, i64 %4768, !dbg !214 + %4895 = getelementptr bfloat, ptr addrspace(1) %4887, i64 %4768, !dbg !214 + %4896 = getelementptr bfloat, ptr addrspace(1) %4888, i64 %4768, !dbg !214 + %4897 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4889, i1 %4777) #3, !dbg !215 + %4898 = extractvalue { i32, i32, i32, i32 } %4897, 0, !dbg !215 + %4899 = extractvalue { i32, i32, i32, i32 } %4897, 1, !dbg !215 + %4900 = extractvalue { i32, i32, i32, i32 } %4897, 2, !dbg !215 + %4901 = extractvalue { i32, i32, i32, i32 } %4897, 3, !dbg !215 + %4902 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4890, i1 %4778) #3, !dbg !215 + %4903 = extractvalue { i32, i32, i32, i32 } %4902, 0, !dbg !215 + %4904 = extractvalue { i32, i32, i32, i32 } %4902, 1, !dbg !215 + %4905 = extractvalue { i32, i32, i32, i32 } %4902, 2, !dbg !215 + %4906 = extractvalue { i32, i32, i32, i32 } %4902, 3, !dbg !215 + %4907 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4891, i1 %4779) #3, !dbg !215 + %4908 = extractvalue { i32, i32, i32, i32 } %4907, 0, !dbg !215 + %4909 = extractvalue { i32, i32, i32, i32 } %4907, 1, !dbg !215 + %4910 = extractvalue { i32, i32, i32, i32 } %4907, 2, !dbg !215 + %4911 = extractvalue { i32, i32, i32, i32 } %4907, 3, !dbg !215 + %4912 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4892, i1 %4780) #3, !dbg !215 + %4913 = extractvalue { i32, i32, i32, i32 } %4912, 0, !dbg !215 + %4914 = extractvalue { i32, i32, i32, i32 } %4912, 1, !dbg !215 + %4915 = extractvalue { i32, i32, i32, i32 } %4912, 2, !dbg !215 + %4916 = extractvalue { i32, i32, i32, i32 } %4912, 3, !dbg !215 + %4917 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4893, i1 %4781) #3, !dbg !215 + %4918 = extractvalue { i32, i32, i32, i32 } %4917, 0, !dbg !215 + %4919 = extractvalue { i32, i32, i32, i32 } %4917, 1, !dbg !215 + %4920 = extractvalue { i32, i32, i32, i32 } %4917, 2, !dbg !215 + %4921 = extractvalue { i32, i32, i32, i32 } %4917, 3, !dbg !215 + %4922 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4894, i1 %4782) #3, !dbg !215 + %4923 = extractvalue { i32, i32, i32, i32 } %4922, 0, !dbg !215 + %4924 = extractvalue { i32, i32, i32, i32 } %4922, 1, !dbg !215 + %4925 = extractvalue { i32, i32, i32, i32 } %4922, 2, !dbg !215 + %4926 = extractvalue { i32, i32, i32, i32 } %4922, 3, !dbg !215 + %4927 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4895, i1 %4783) #3, !dbg !215 + %4928 = extractvalue { i32, i32, i32, i32 } %4927, 0, !dbg !215 + %4929 = extractvalue { i32, i32, i32, i32 } %4927, 1, !dbg !215 + %4930 = extractvalue { i32, i32, i32, i32 } %4927, 2, !dbg !215 + %4931 = extractvalue { i32, i32, i32, i32 } %4927, 3, !dbg !215 + %4932 = tail call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, $4;\0A\09mov.u32 $1, $5;\0A\09mov.u32 $2, $6;\0A\09mov.u32 $3, $7;\0A\09@$9 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $8 + 0 ];", "=r,=r,=r,=r,r,r,r,r,l,b"(i32 0, i32 0, i32 0, i32 0, ptr addrspace(1) %4896, i1 %4784) #3, !dbg !215 + %4933 = extractvalue { i32, i32, i32, i32 } %4932, 0, !dbg !215 + %4934 = extractvalue { i32, i32, i32, i32 } %4932, 1, !dbg !215 + %4935 = extractvalue { i32, i32, i32, i32 } %4932, 2, !dbg !215 + %4936 = extractvalue { i32, i32, i32, i32 } %4932, 3, !dbg !215 + %4937 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4833, !dbg !215 + %4938 = insertelement <4 x i32> poison, i32 %4898, i64 0, !dbg !215 + %4939 = insertelement <4 x i32> %4938, i32 %4899, i64 1, !dbg !215 + %4940 = insertelement <4 x i32> %4939, i32 %4900, i64 2, !dbg !215 + %4941 = insertelement <4 x i32> %4940, i32 %4901, i64 3, !dbg !215 + store <4 x i32> %4941, ptr addrspace(3) %4937, align 16, !dbg !215 + %4942 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4839, !dbg !215 + %4943 = insertelement <4 x i32> poison, i32 %4903, i64 0, !dbg !215 + %4944 = insertelement <4 x i32> %4943, i32 %4904, i64 1, !dbg !215 + %4945 = insertelement <4 x i32> %4944, i32 %4905, i64 2, !dbg !215 + %4946 = insertelement <4 x i32> %4945, i32 %4906, i64 3, !dbg !215 + store <4 x i32> %4946, ptr addrspace(3) %4942, align 16, !dbg !215 + %4947 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4845, !dbg !215 + %4948 = insertelement <4 x i32> poison, i32 %4908, i64 0, !dbg !215 + %4949 = insertelement <4 x i32> %4948, i32 %4909, i64 1, !dbg !215 + %4950 = insertelement <4 x i32> %4949, i32 %4910, i64 2, !dbg !215 + %4951 = insertelement <4 x i32> %4950, i32 %4911, i64 3, !dbg !215 + store <4 x i32> %4951, ptr addrspace(3) %4947, align 16, !dbg !215 + %4952 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4851, !dbg !215 + %4953 = insertelement <4 x i32> poison, i32 %4913, i64 0, !dbg !215 + %4954 = insertelement <4 x i32> %4953, i32 %4914, i64 1, !dbg !215 + %4955 = insertelement <4 x i32> %4954, i32 %4915, i64 2, !dbg !215 + %4956 = insertelement <4 x i32> %4955, i32 %4916, i64 3, !dbg !215 + store <4 x i32> %4956, ptr addrspace(3) %4952, align 16, !dbg !215 + %4957 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4857, !dbg !215 + %4958 = insertelement <4 x i32> poison, i32 %4918, i64 0, !dbg !215 + %4959 = insertelement <4 x i32> %4958, i32 %4919, i64 1, !dbg !215 + %4960 = insertelement <4 x i32> %4959, i32 %4920, i64 2, !dbg !215 + %4961 = insertelement <4 x i32> %4960, i32 %4921, i64 3, !dbg !215 + store <4 x i32> %4961, ptr addrspace(3) %4957, align 16, !dbg !215 + %4962 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4863, !dbg !215 + %4963 = insertelement <4 x i32> poison, i32 %4923, i64 0, !dbg !215 + %4964 = insertelement <4 x i32> %4963, i32 %4924, i64 1, !dbg !215 + %4965 = insertelement <4 x i32> %4964, i32 %4925, i64 2, !dbg !215 + %4966 = insertelement <4 x i32> %4965, i32 %4926, i64 3, !dbg !215 + store <4 x i32> %4966, ptr addrspace(3) %4962, align 16, !dbg !215 + %4967 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4869, !dbg !215 + %4968 = insertelement <4 x i32> poison, i32 %4928, i64 0, !dbg !215 + %4969 = insertelement <4 x i32> %4968, i32 %4929, i64 1, !dbg !215 + %4970 = insertelement <4 x i32> %4969, i32 %4930, i64 2, !dbg !215 + %4971 = insertelement <4 x i32> %4970, i32 %4931, i64 3, !dbg !215 + store <4 x i32> %4971, ptr addrspace(3) %4967, align 16, !dbg !215 + %4972 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 %4875, !dbg !215 + %4973 = insertelement <4 x i32> poison, i32 %4933, i64 0, !dbg !215 + %4974 = insertelement <4 x i32> %4973, i32 %4934, i64 1, !dbg !215 + %4975 = insertelement <4 x i32> %4974, i32 %4935, i64 2, !dbg !215 + %4976 = insertelement <4 x i32> %4975, i32 %4936, i64 3, !dbg !215 + store <4 x i32> %4976, ptr addrspace(3) %4972, align 16, !dbg !215 + %4977 = shl nuw nsw i32 %34, 2, !dbg !216 + %4978 = mul i32 %22, %33, !dbg !217 + %4979 = mul i32 %28, %33, !dbg !218 + %4980 = shl nuw nsw i32 %33, 5, !dbg !219 + %4981 = zext nneg i32 %30 to i64, !dbg !220 + %4982 = getelementptr i32, ptr addrspace(1) %11, i64 %4981, !dbg !220 + %4983 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4982, i1 true) #3, !dbg !221 + %4984 = shl i32 %4983, 7, !dbg !222 + %4985 = getelementptr i32, ptr addrspace(1) %10, i64 %4981, !dbg !223 + %4986 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %4985, i1 true) #3, !dbg !224 + %4987 = and i32 %44, 3, !dbg !225 + %4988 = shl nuw nsw i32 %4987, 1, !dbg !225 + %4989 = or disjoint i32 %4988, 1, !dbg !225 + %4990 = or disjoint i32 %4988, 8, !dbg !225 + %4991 = or disjoint i32 %4988, 9, !dbg !225 + %4992 = or disjoint i32 %4984, %47, !dbg !226 + %4993 = or disjoint i32 %4984, %48, !dbg !226 + %4994 = or disjoint i32 %4984, %49, !dbg !226 + %4995 = or disjoint i32 %4984, %50, !dbg !226 + %4996 = shl i32 %4992, 12, !dbg !227 + %4997 = shl i32 %4993, 12, !dbg !227 + %4998 = shl i32 %4994, 12, !dbg !227 + %4999 = shl i32 %4995, 12, !dbg !227 + %5000 = shl i32 %4992, 7, !dbg !229 + %5001 = shl i32 %4993, 7, !dbg !229 + %5002 = shl i32 %4994, 7, !dbg !229 + %5003 = shl i32 %4995, 7, !dbg !229 + %5004 = shl i32 %4986, 1, !dbg !230 + %5005 = add i32 %17, 63, !dbg !231 + %5006 = sdiv i32 %5005, 64, !dbg !232 + %5007 = tail call i32 @llvm.smax.i32(i32 %5006, i32 1), !dbg !233 + %5008 = tail call i32 @llvm.smin.i32(i32 %5004, i32 %5007), !dbg !234 + %5009 = insertelement <2 x i32> poison, i32 %60, i64 0, !dbg !204 + %5010 = insertelement <2 x i32> %5009, i32 %59, i64 1, !dbg !204 + %5011 = insertelement <2 x i32> poison, i32 %4733, i64 0, !dbg !204 + %5012 = shufflevector <2 x i32> %5011, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !204 + %5013 = or disjoint <2 x i32> %5010, %5012, !dbg !204 + %5014 = insertelement <4 x i32> poison, i32 %4988, i64 0, !dbg !225 + %5015 = shufflevector <4 x i32> %5014, <4 x i32> poison, <4 x i32> zeroinitializer, !dbg !225 + %5016 = or disjoint <4 x i32> %5015, , !dbg !225 + %5017 = insertelement <8 x i32> poison, i32 %4988, i64 0, !dbg !225 + %5018 = shufflevector <8 x i32> %5017, <8 x i32> poison, <8 x i32> zeroinitializer, !dbg !225 + %5019 = or disjoint <8 x i32> %5018, , !dbg !225 + %5020 = insertelement <16 x i32> poison, i32 %4984, i64 0, !dbg !226 + %5021 = shufflevector <16 x i32> %5020, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !226 + %5022 = insertelement <16 x i32> poison, i32 %4991, i64 12, !dbg !226 + %5023 = insertelement <16 x i32> %5022, i32 %4990, i64 13, !dbg !226 + %5024 = insertelement <16 x i32> %5023, i32 %4989, i64 14, !dbg !226 + %5025 = insertelement <16 x i32> %5024, i32 %4988, i64 15, !dbg !226 + %5026 = shufflevector <8 x i32> %5019, <8 x i32> poison, <16 x i32> , !dbg !226 + %5027 = shufflevector <16 x i32> %5026, <16 x i32> %5025, <16 x i32> , !dbg !226 + %5028 = shufflevector <4 x i32> %5016, <4 x i32> poison, <16 x i32> , !dbg !226 + %5029 = shufflevector <16 x i32> %5027, <16 x i32> %5028, <16 x i32> , !dbg !226 + %5030 = or disjoint <16 x i32> %5021, %5029, !dbg !226 + %5031 = insertelement <2 x i32> poison, i32 %18, i64 0, !dbg !235 + %5032 = shufflevector <2 x i32> %5031, <2 x i32> poison, <2 x i32> zeroinitializer, !dbg !235 + %5033 = srem <2 x i32> %5013, %5032, !dbg !235 + %5034 = lshr <2 x i32> %5033, splat (i32 4), !dbg !236 + %5035 = shufflevector <2 x i32> %5034, <2 x i32> poison, <32 x i32> , !dbg !236 + %5036 = getelementptr i32, ptr addrspace(1) %15, i64 %4981, !dbg !237 + %5037 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %5036, i1 true) #3, !dbg !238 + %5038 = shl i32 %5037, 7, !dbg !239 + %5039 = getelementptr i32, ptr addrspace(1) %14, i64 %4981, !dbg !240 + %5040 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %5039, i1 true) #3, !dbg !241 + %5041 = or disjoint i32 %5038, %4988, !dbg !242 + %5042 = or disjoint i32 %5038, %4989, !dbg !242 + %5043 = or disjoint i32 %5038, %4990, !dbg !242 + %5044 = or disjoint i32 %5038, %4991, !dbg !242 + %5045 = extractelement <4 x i32> %5016, i64 3, !dbg !242 + %5046 = or disjoint i32 %5038, %5045, !dbg !242 + %5047 = extractelement <4 x i32> %5016, i64 2, !dbg !242 + %5048 = or disjoint i32 %5038, %5047, !dbg !242 + %5049 = extractelement <4 x i32> %5016, i64 1, !dbg !242 + %5050 = or disjoint i32 %5038, %5049, !dbg !242 + %5051 = extractelement <4 x i32> %5016, i64 0, !dbg !242 + %5052 = or disjoint i32 %5038, %5051, !dbg !242 + %5053 = extractelement <8 x i32> %5019, i64 7, !dbg !242 + %5054 = or disjoint i32 %5038, %5053, !dbg !242 + %5055 = extractelement <8 x i32> %5019, i64 6, !dbg !242 + %5056 = or disjoint i32 %5038, %5055, !dbg !242 + %5057 = extractelement <8 x i32> %5019, i64 5, !dbg !242 + %5058 = or disjoint i32 %5038, %5057, !dbg !242 + %5059 = extractelement <8 x i32> %5019, i64 4, !dbg !242 + %5060 = or disjoint i32 %5038, %5059, !dbg !242 + %5061 = extractelement <8 x i32> %5019, i64 3, !dbg !242 + %5062 = or disjoint i32 %5038, %5061, !dbg !242 + %5063 = extractelement <8 x i32> %5019, i64 2, !dbg !242 + %5064 = or disjoint i32 %5038, %5063, !dbg !242 + %5065 = extractelement <8 x i32> %5019, i64 1, !dbg !242 + %5066 = or disjoint i32 %5038, %5065, !dbg !242 + %5067 = extractelement <8 x i32> %5019, i64 0, !dbg !242 + %5068 = or disjoint i32 %5038, %5067, !dbg !242 + %5069 = or disjoint i32 %5038, %47, !dbg !242 + %5070 = or disjoint i32 %5038, %48, !dbg !242 + %5071 = or disjoint i32 %5038, %49, !dbg !242 + %5072 = or disjoint i32 %5038, %50, !dbg !242 + %5073 = shl i32 %5069, 12, !dbg !243 + %5074 = shl i32 %5070, 12, !dbg !243 + %5075 = shl i32 %5071, 12, !dbg !243 + %5076 = shl i32 %5072, 12, !dbg !243 + %5077 = shl i32 %5069, 7, !dbg !245 + %5078 = shl i32 %5070, 7, !dbg !245 + %5079 = shl i32 %5071, 7, !dbg !245 + %5080 = shl i32 %5072, 7, !dbg !245 + %5081 = shl i32 %5040, 1, !dbg !246 + %5082 = tail call i32 @llvm.smin.i32(i32 %5081, i32 %5007), !dbg !247 + tail call void asm sideeffect "fence.proxy.async.shared::cta;", ""() #3, !dbg !248 + %5083 = sext i32 %4996 to i64 + %5084 = sext i32 %4997 to i64 + %5085 = sext i32 %4998 to i64 + %5086 = sext i32 %4999 to i64 + %5087 = sext i32 %5000 to i64 + %5088 = sext i32 %5001 to i64 + %5089 = sext i32 %5002 to i64 + %5090 = sext i32 %5003 to i64 + %5091 = icmp sgt i32 %5004, 0 + %5092 = icmp slt i32 %4992, %17 + %5093 = icmp slt i32 %4993, %17 + %5094 = icmp slt i32 %4994, %17 + %5095 = icmp slt i32 %4995, %17 + %5096 = and i1 %5091, %5092 + %5097 = and i1 %5091, %5093 + %5098 = and i1 %5091, %5094 + %5099 = and i1 %5091, %5095 + %5100 = shl nuw nsw i32 %4829, 10 + %5101 = or disjoint i32 %4832, %5100 + %5102 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5101 + %5103 = select i1 %5096, i32 16, i32 0 + %5104 = or disjoint i32 %5101, 2048 + %5105 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5104 + %5106 = select i1 %5097, i32 16, i32 0 + %5107 = or disjoint i32 %5101, 4096 + %5108 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5107 + %5109 = select i1 %5098, i32 16, i32 0 + %5110 = or disjoint i32 %5101, 6144 + %5111 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %5110 + %5112 = select i1 %5099, i32 16, i32 0 + %5113 = extractelement <16 x i32> %5030, i64 15 + %5114 = icmp slt i32 %5113, %17 + %5115 = extractelement <16 x i32> %5030, i64 14 + %5116 = icmp slt i32 %5115, %17 + %5117 = extractelement <16 x i32> %5030, i64 13 + %5118 = icmp slt i32 %5117, %17 + %5119 = extractelement <16 x i32> %5030, i64 12 + %5120 = icmp slt i32 %5119, %17 + %5121 = extractelement <16 x i32> %5030, i64 11 + %5122 = icmp slt i32 %5121, %17 + %5123 = extractelement <16 x i32> %5030, i64 10 + %5124 = icmp slt i32 %5123, %17 + %5125 = extractelement <16 x i32> %5030, i64 9 + %5126 = icmp slt i32 %5125, %17 + %5127 = extractelement <16 x i32> %5030, i64 8 + %5128 = icmp slt i32 %5127, %17 + %5129 = extractelement <16 x i32> %5030, i64 7 + %5130 = icmp slt i32 %5129, %17 + %5131 = extractelement <16 x i32> %5030, i64 6 + %5132 = icmp slt i32 %5131, %17 + %5133 = extractelement <16 x i32> %5030, i64 5 + %5134 = icmp slt i32 %5133, %17 + %5135 = extractelement <16 x i32> %5030, i64 4 + %5136 = icmp slt i32 %5135, %17 + %5137 = extractelement <16 x i32> %5030, i64 3 + %5138 = icmp slt i32 %5137, %17 + %5139 = extractelement <16 x i32> %5030, i64 2 + %5140 = icmp slt i32 %5139, %17 + %5141 = extractelement <16 x i32> %5030, i64 1 + %5142 = icmp slt i32 %5141, %17 + %5143 = extractelement <16 x i32> %5030, i64 0 + %5144 = icmp slt i32 %5143, %17 + %5145 = sext i32 %5113 to i64 + %5146 = sext i32 %5115 to i64 + %5147 = sext i32 %5117 to i64 + %5148 = sext i32 %5119 to i64 + %5149 = sext i32 %5121 to i64 + %5150 = sext i32 %5123 to i64 + %5151 = sext i32 %5125 to i64 + %5152 = sext i32 %5127 to i64 + %5153 = sext i32 %5129 to i64 + %5154 = sext i32 %5131 to i64 + %5155 = sext i32 %5133 to i64 + %5156 = sext i32 %5135 to i64 + %5157 = sext i32 %5137 to i64 + %5158 = sext i32 %5139 to i64 + %5159 = sext i32 %5141 to i64 + %5160 = sext i32 %5143 to i64 + %5161 = and i1 %5091, %5114 + %5162 = and i1 %5091, %5116 + %5163 = and i1 %5091, %5118 + %5164 = and i1 %5091, %5120 + %5165 = and i1 %5091, %5122 + %5166 = and i1 %5091, %5124 + %5167 = and i1 %5091, %5126 + %5168 = and i1 %5091, %5128 + %5169 = and i1 %5091, %5130 + %5170 = and i1 %5091, %5132 + %5171 = and i1 %5091, %5134 + %5172 = and i1 %5091, %5136 + %5173 = and i1 %5091, %5138 + %5174 = and i1 %5091, %5140 + %5175 = and i1 %5091, %5142 + %5176 = and i1 %5091, %5144 + %5177 = and i32 %44, 252 + %5178 = icmp eq i32 %5177, 0 + %5179 = shl nuw nsw i32 %4987, 3 + %5180 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5179 + %5181 = select i1 %5161, i32 4, i32 0 + %5182 = or disjoint i32 %5179, 4 + %5183 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5182 + %5184 = select i1 %5162, i32 4, i32 0 + %5185 = or disjoint i32 %5179, 32 + %5186 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5185 + %5187 = select i1 %5163, i32 4, i32 0 + %5188 = or disjoint i32 %5179, 36 + %5189 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5188 + %5190 = select i1 %5164, i32 4, i32 0 + %5191 = or disjoint i32 %5179, 64 + %5192 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5191 + %5193 = select i1 %5165, i32 4, i32 0 + %5194 = or disjoint i32 %5179, 68 + %5195 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5194 + %5196 = select i1 %5166, i32 4, i32 0 + %5197 = or disjoint i32 %5179, 96 + %5198 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5197 + %5199 = select i1 %5167, i32 4, i32 0 + %5200 = or disjoint i32 %5179, 100 + %5201 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5200 + %5202 = select i1 %5168, i32 4, i32 0 + %5203 = or disjoint i32 %5179, 128 + %5204 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5203 + %5205 = select i1 %5169, i32 4, i32 0 + %5206 = or disjoint i32 %5179, 132 + %5207 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5206 + %5208 = select i1 %5170, i32 4, i32 0 + %5209 = or disjoint i32 %5179, 160 + %5210 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5209 + %5211 = select i1 %5171, i32 4, i32 0 + %5212 = or disjoint i32 %5179, 164 + %5213 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5212 + %5214 = select i1 %5172, i32 4, i32 0 + %5215 = or disjoint i32 %5179, 192 + %5216 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5215 + %5217 = select i1 %5173, i32 4, i32 0 + %5218 = or disjoint i32 %5179, 196 + %5219 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5218 + %5220 = select i1 %5174, i32 4, i32 0 + %5221 = or disjoint i32 %5179, 224 + %5222 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5221 + %5223 = select i1 %5175, i32 4, i32 0 + %5224 = or disjoint i32 %5179, 228 + %5225 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5224 + %5226 = select i1 %5176, i32 4, i32 0 + %5227 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5101 + %5228 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5104 + %5229 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5107 + %5230 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5110 + %5231 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5179 + %5232 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5182 + %5233 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5185 + %5234 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5188 + %5235 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5191 + %5236 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5194 + %5237 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5197 + %5238 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5200 + %5239 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5203 + %5240 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5206 + %5241 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5209 + %5242 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5212 + %5243 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5215 + %5244 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5218 + %5245 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5221 + %5246 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5224 + %5247 = icmp sgt i32 %5008, 1 + %5248 = or disjoint i32 %5113, 64 + %5249 = or disjoint i32 %5115, 64 + %5250 = or disjoint i32 %5117, 64 + %5251 = or disjoint i32 %5119, 64 + %5252 = or disjoint i32 %5121, 64 + %5253 = or disjoint i32 %5123, 64 + %5254 = or disjoint i32 %5125, 64 + %5255 = or disjoint i32 %5127, 64 + %5256 = or disjoint i32 %5129, 64 + %5257 = or disjoint i32 %5131, 64 + %5258 = or disjoint i32 %5133, 64 + %5259 = or disjoint i32 %5135, 64 + %5260 = or disjoint i32 %5137, 64 + %5261 = or disjoint i32 %5139, 64 + %5262 = or disjoint i32 %5141, 64 + %5263 = or disjoint i32 %5143, 64 + %5264 = or disjoint i32 %4992, 64 + %5265 = or disjoint i32 %4993, 64 + %5266 = or disjoint i32 %4994, 64 + %5267 = or disjoint i32 %4995, 64 + %5268 = icmp slt i32 %5264, %17 + %5269 = icmp slt i32 %5265, %17 + %5270 = icmp slt i32 %5266, %17 + %5271 = icmp slt i32 %5267, %17 + %5272 = and i1 %5247, %5268 + %5273 = and i1 %5247, %5269 + %5274 = and i1 %5247, %5270 + %5275 = and i1 %5247, %5271 + %5276 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5101 + %5277 = select i1 %5272, i32 16, i32 0 + %5278 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5104 + %5279 = select i1 %5273, i32 16, i32 0 + %5280 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5107 + %5281 = select i1 %5274, i32 16, i32 0 + %5282 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 16384), i32 %5110 + %5283 = select i1 %5275, i32 16, i32 0 + %5284 = icmp slt i32 %5248, %17 + %5285 = icmp slt i32 %5249, %17 + %5286 = icmp slt i32 %5250, %17 + %5287 = icmp slt i32 %5251, %17 + %5288 = icmp slt i32 %5252, %17 + %5289 = icmp slt i32 %5253, %17 + %5290 = icmp slt i32 %5254, %17 + %5291 = icmp slt i32 %5255, %17 + %5292 = icmp slt i32 %5256, %17 + %5293 = icmp slt i32 %5257, %17 + %5294 = icmp slt i32 %5258, %17 + %5295 = icmp slt i32 %5259, %17 + %5296 = icmp slt i32 %5260, %17 + %5297 = icmp slt i32 %5261, %17 + %5298 = icmp slt i32 %5262, %17 + %5299 = icmp slt i32 %5263, %17 + %5300 = sext i32 %5248 to i64 + %5301 = sext i32 %5249 to i64 + %5302 = sext i32 %5250 to i64 + %5303 = sext i32 %5251 to i64 + %5304 = sext i32 %5252 to i64 + %5305 = sext i32 %5253 to i64 + %5306 = sext i32 %5254 to i64 + %5307 = sext i32 %5255 to i64 + %5308 = sext i32 %5256 to i64 + %5309 = sext i32 %5257 to i64 + %5310 = sext i32 %5258 to i64 + %5311 = sext i32 %5259 to i64 + %5312 = sext i32 %5260 to i64 + %5313 = sext i32 %5261 to i64 + %5314 = sext i32 %5262 to i64 + %5315 = sext i32 %5263 to i64 + %5316 = and i1 %5247, %5284 + %5317 = and i1 %5247, %5285 + %5318 = and i1 %5247, %5286 + %5319 = and i1 %5247, %5287 + %5320 = and i1 %5247, %5288 + %5321 = and i1 %5247, %5289 + %5322 = and i1 %5247, %5290 + %5323 = and i1 %5247, %5291 + %5324 = and i1 %5247, %5292 + %5325 = and i1 %5247, %5293 + %5326 = and i1 %5247, %5294 + %5327 = and i1 %5247, %5295 + %5328 = and i1 %5247, %5296 + %5329 = and i1 %5247, %5297 + %5330 = and i1 %5247, %5298 + %5331 = and i1 %5247, %5299 + %5332 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5179 + %5333 = select i1 %5316, i32 4, i32 0 + %5334 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5182 + %5335 = select i1 %5317, i32 4, i32 0 + %5336 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5185 + %5337 = select i1 %5318, i32 4, i32 0 + %5338 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5188 + %5339 = select i1 %5319, i32 4, i32 0 + %5340 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5191 + %5341 = select i1 %5320, i32 4, i32 0 + %5342 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5194 + %5343 = select i1 %5321, i32 4, i32 0 + %5344 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5197 + %5345 = select i1 %5322, i32 4, i32 0 + %5346 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5200 + %5347 = select i1 %5323, i32 4, i32 0 + %5348 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5203 + %5349 = select i1 %5324, i32 4, i32 0 + %5350 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5206 + %5351 = select i1 %5325, i32 4, i32 0 + %5352 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5209 + %5353 = select i1 %5326, i32 4, i32 0 + %5354 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5212 + %5355 = select i1 %5327, i32 4, i32 0 + %5356 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5215 + %5357 = select i1 %5328, i32 4, i32 0 + %5358 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5218 + %5359 = select i1 %5329, i32 4, i32 0 + %5360 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5221 + %5361 = select i1 %5330, i32 4, i32 0 + %5362 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98560), i32 %5224 + %5363 = select i1 %5331, i32 4, i32 0 + %5364 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5101 + %5365 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5104 + %5366 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5107 + %5367 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 65536), i32 %5110 + %5368 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5179 + %5369 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5182 + %5370 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5185 + %5371 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5188 + %5372 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5191 + %5373 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5194 + %5374 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5197 + %5375 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5200 + %5376 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5203 + %5377 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5206 + %5378 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5209 + %5379 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5212 + %5380 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5215 + %5381 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5218 + %5382 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5221 + %5383 = getelementptr inbounds nuw i8, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99072), i32 %5224 + %5384 = add i32 %5008, -2 + %5385 = add nsw i32 %5008, -1 + %5386 = sext i32 %5073 to i64 + %5387 = sext i32 %5074 to i64 + %5388 = sext i32 %5075 to i64 + %5389 = sext i32 %5076 to i64 + %5390 = sext i32 %5077 to i64 + %5391 = sext i32 %5078 to i64 + %5392 = sext i32 %5079 to i64 + %5393 = sext i32 %5080 to i64 + %5394 = icmp sgt i32 %5081, 0 + %5395 = icmp slt i32 %5069, %17 + %5396 = icmp slt i32 %5070, %17 + %5397 = icmp slt i32 %5071, %17 + %5398 = icmp slt i32 %5072, %17 + %5399 = and i1 %5394, %5395 + %5400 = and i1 %5394, %5396 + %5401 = and i1 %5394, %5397 + %5402 = and i1 %5394, %5398 + %5403 = select i1 %5399, i32 16, i32 0 + %5404 = select i1 %5400, i32 16, i32 0 + %5405 = select i1 %5401, i32 16, i32 0 + %5406 = select i1 %5402, i32 16, i32 0 + %5407 = icmp slt i32 %5041, %17 + %5408 = icmp slt i32 %5042, %17 + %5409 = icmp slt i32 %5043, %17 + %5410 = icmp slt i32 %5044, %17 + %5411 = icmp slt i32 %5046, %17 + %5412 = icmp slt i32 %5048, %17 + %5413 = icmp slt i32 %5050, %17 + %5414 = icmp slt i32 %5052, %17 + %5415 = icmp slt i32 %5054, %17 + %5416 = icmp slt i32 %5056, %17 + %5417 = icmp slt i32 %5058, %17 + %5418 = icmp slt i32 %5060, %17 + %5419 = icmp slt i32 %5062, %17 + %5420 = icmp slt i32 %5064, %17 + %5421 = icmp slt i32 %5066, %17 + %5422 = icmp slt i32 %5068, %17 + %5423 = sext i32 %5041 to i64 + %5424 = sext i32 %5042 to i64 + %5425 = sext i32 %5043 to i64 + %5426 = sext i32 %5044 to i64 + %5427 = sext i32 %5046 to i64 + %5428 = sext i32 %5048 to i64 + %5429 = sext i32 %5050 to i64 + %5430 = sext i32 %5052 to i64 + %5431 = sext i32 %5054 to i64 + %5432 = sext i32 %5056 to i64 + %5433 = sext i32 %5058 to i64 + %5434 = sext i32 %5060 to i64 + %5435 = sext i32 %5062 to i64 + %5436 = sext i32 %5064 to i64 + %5437 = sext i32 %5066 to i64 + %5438 = sext i32 %5068 to i64 + %5439 = and i1 %5394, %5407 + %5440 = and i1 %5394, %5408 + %5441 = and i1 %5394, %5409 + %5442 = and i1 %5394, %5410 + %5443 = and i1 %5394, %5411 + %5444 = and i1 %5394, %5412 + %5445 = and i1 %5394, %5413 + %5446 = and i1 %5394, %5414 + %5447 = and i1 %5394, %5415 + %5448 = and i1 %5394, %5416 + %5449 = and i1 %5394, %5417 + %5450 = and i1 %5394, %5418 + %5451 = and i1 %5394, %5419 + %5452 = and i1 %5394, %5420 + %5453 = and i1 %5394, %5421 + %5454 = and i1 %5394, %5422 + %5455 = select i1 %5439, i32 4, i32 0 + %5456 = select i1 %5440, i32 4, i32 0 + %5457 = select i1 %5441, i32 4, i32 0 + %5458 = select i1 %5442, i32 4, i32 0 + %5459 = select i1 %5443, i32 4, i32 0 + %5460 = select i1 %5444, i32 4, i32 0 + %5461 = select i1 %5445, i32 4, i32 0 + %5462 = select i1 %5446, i32 4, i32 0 + %5463 = select i1 %5447, i32 4, i32 0 + %5464 = select i1 %5448, i32 4, i32 0 + %5465 = select i1 %5449, i32 4, i32 0 + %5466 = select i1 %5450, i32 4, i32 0 + %5467 = select i1 %5451, i32 4, i32 0 + %5468 = select i1 %5452, i32 4, i32 0 + %5469 = select i1 %5453, i32 4, i32 0 + %5470 = select i1 %5454, i32 4, i32 0 + %5471 = icmp sgt i32 %5082, 1 + %5472 = or disjoint i32 %5041, 64 + %5473 = or disjoint i32 %5042, 64 + %5474 = or disjoint i32 %5043, 64 + %5475 = or disjoint i32 %5044, 64 + %5476 = or disjoint i32 %5046, 64 + %5477 = or disjoint i32 %5048, 64 + %5478 = or disjoint i32 %5050, 64 + %5479 = or disjoint i32 %5052, 64 + %5480 = or disjoint i32 %5054, 64 + %5481 = or disjoint i32 %5056, 64 + %5482 = or disjoint i32 %5058, 64 + %5483 = or disjoint i32 %5060, 64 + %5484 = or disjoint i32 %5062, 64 + %5485 = or disjoint i32 %5064, 64 + %5486 = or disjoint i32 %5066, 64 + %5487 = or disjoint i32 %5068, 64 + %5488 = or disjoint i32 %5069, 64 + %5489 = or disjoint i32 %5070, 64 + %5490 = or disjoint i32 %5071, 64 + %5491 = or disjoint i32 %5072, 64 + %5492 = icmp slt i32 %5488, %17 + %5493 = icmp slt i32 %5489, %17 + %5494 = icmp slt i32 %5490, %17 + %5495 = icmp slt i32 %5491, %17 + %5496 = and i1 %5471, %5492 + %5497 = and i1 %5471, %5493 + %5498 = and i1 %5471, %5494 + %5499 = and i1 %5471, %5495 + %5500 = select i1 %5496, i32 16, i32 0 + %5501 = select i1 %5497, i32 16, i32 0 + %5502 = select i1 %5498, i32 16, i32 0 + %5503 = select i1 %5499, i32 16, i32 0 + %5504 = icmp slt i32 %5472, %17 + %5505 = icmp slt i32 %5473, %17 + %5506 = icmp slt i32 %5474, %17 + %5507 = icmp slt i32 %5475, %17 + %5508 = icmp slt i32 %5476, %17 + %5509 = icmp slt i32 %5477, %17 + %5510 = icmp slt i32 %5478, %17 + %5511 = icmp slt i32 %5479, %17 + %5512 = icmp slt i32 %5480, %17 + %5513 = icmp slt i32 %5481, %17 + %5514 = icmp slt i32 %5482, %17 + %5515 = icmp slt i32 %5483, %17 + %5516 = icmp slt i32 %5484, %17 + %5517 = icmp slt i32 %5485, %17 + %5518 = icmp slt i32 %5486, %17 + %5519 = icmp slt i32 %5487, %17 + %5520 = sext i32 %5472 to i64 + %5521 = sext i32 %5473 to i64 + %5522 = sext i32 %5474 to i64 + %5523 = sext i32 %5475 to i64 + %5524 = sext i32 %5476 to i64 + %5525 = sext i32 %5477 to i64 + %5526 = sext i32 %5478 to i64 + %5527 = sext i32 %5479 to i64 + %5528 = sext i32 %5480 to i64 + %5529 = sext i32 %5481 to i64 + %5530 = sext i32 %5482 to i64 + %5531 = sext i32 %5483 to i64 + %5532 = sext i32 %5484 to i64 + %5533 = sext i32 %5485 to i64 + %5534 = sext i32 %5486 to i64 + %5535 = sext i32 %5487 to i64 + %5536 = and i1 %5471, %5504 + %5537 = and i1 %5471, %5505 + %5538 = and i1 %5471, %5506 + %5539 = and i1 %5471, %5507 + %5540 = and i1 %5471, %5508 + %5541 = and i1 %5471, %5509 + %5542 = and i1 %5471, %5510 + %5543 = and i1 %5471, %5511 + %5544 = and i1 %5471, %5512 + %5545 = and i1 %5471, %5513 + %5546 = and i1 %5471, %5514 + %5547 = and i1 %5471, %5515 + %5548 = and i1 %5471, %5516 + %5549 = and i1 %5471, %5517 + %5550 = and i1 %5471, %5518 + %5551 = and i1 %5471, %5519 + %5552 = select i1 %5536, i32 4, i32 0 + %5553 = select i1 %5537, i32 4, i32 0 + %5554 = select i1 %5538, i32 4, i32 0 + %5555 = select i1 %5539, i32 4, i32 0 + %5556 = select i1 %5540, i32 4, i32 0 + %5557 = select i1 %5541, i32 4, i32 0 + %5558 = select i1 %5542, i32 4, i32 0 + %5559 = select i1 %5543, i32 4, i32 0 + %5560 = select i1 %5544, i32 4, i32 0 + %5561 = select i1 %5545, i32 4, i32 0 + %5562 = select i1 %5546, i32 4, i32 0 + %5563 = select i1 %5547, i32 4, i32 0 + %5564 = select i1 %5548, i32 4, i32 0 + %5565 = select i1 %5549, i32 4, i32 0 + %5566 = select i1 %5550, i32 4, i32 0 + %5567 = select i1 %5551, i32 4, i32 0 + %5568 = add i32 %5082, -2 + %5569 = add nsw i32 %5082, -1 + %smax2265 = tail call i32 @llvm.smax.i32(i32 %5008, i32 1), !dbg !249 + %smax2267 = tail call i32 @llvm.smax.i32(i32 %5082, i32 1), !dbg !249 + %5570 = zext nneg i32 %4977 to i64, !dbg !249 + %5571 = insertelement <16 x i32> poison, i32 %17, i64 0 + %5572 = shufflevector <16 x i32> %5571, <16 x i32> poison, <16 x i32> zeroinitializer + %5573 = extractelement <2 x i32> %5033, i64 1 + %5574 = extractelement <2 x i32> %5033, i64 0 + br label %5575, !dbg !249 + +5575: ; preds = %4732, %._crit_edge1874 + %indvars.iv = phi i64 [ 0, %4732 ], [ %indvars.iv.next, %._crit_edge1874 ] + %5576 = phi float [ 0.000000e+00, %4732 ], [ %11035, %._crit_edge1874 ] + %5577 = phi float [ 0.000000e+00, %4732 ], [ %11036, %._crit_edge1874 ] + %5578 = phi float [ 0.000000e+00, %4732 ], [ %11037, %._crit_edge1874 ] + %5579 = phi float [ 0.000000e+00, %4732 ], [ %11038, %._crit_edge1874 ] + %5580 = phi float [ 0.000000e+00, %4732 ], [ %11039, %._crit_edge1874 ] + %5581 = phi float [ 0.000000e+00, %4732 ], [ %11040, %._crit_edge1874 ] + %5582 = phi float [ 0.000000e+00, %4732 ], [ %11041, %._crit_edge1874 ] + %5583 = phi float [ 0.000000e+00, %4732 ], [ %11042, %._crit_edge1874 ] + %5584 = phi float [ 0.000000e+00, %4732 ], [ %11043, %._crit_edge1874 ] + %5585 = phi float [ 0.000000e+00, %4732 ], [ %11044, %._crit_edge1874 ] + %5586 = phi float [ 0.000000e+00, %4732 ], [ %11045, %._crit_edge1874 ] + %5587 = phi float [ 0.000000e+00, %4732 ], [ %11046, %._crit_edge1874 ] + %5588 = phi float [ 0.000000e+00, %4732 ], [ %11047, %._crit_edge1874 ] + %5589 = phi float [ 0.000000e+00, %4732 ], [ %11048, %._crit_edge1874 ] + %5590 = phi float [ 0.000000e+00, %4732 ], [ %11049, %._crit_edge1874 ] + %5591 = phi float [ 0.000000e+00, %4732 ], [ %11050, %._crit_edge1874 ] + %5592 = phi float [ 0.000000e+00, %4732 ], [ %11051, %._crit_edge1874 ] + %5593 = phi float [ 0.000000e+00, %4732 ], [ %11052, %._crit_edge1874 ] + %5594 = phi float [ 0.000000e+00, %4732 ], [ %11053, %._crit_edge1874 ] + %5595 = phi float [ 0.000000e+00, %4732 ], [ %11054, %._crit_edge1874 ] + %5596 = phi float [ 0.000000e+00, %4732 ], [ %11055, %._crit_edge1874 ] + %5597 = phi float [ 0.000000e+00, %4732 ], [ %11056, %._crit_edge1874 ] + %5598 = phi float [ 0.000000e+00, %4732 ], [ %11057, %._crit_edge1874 ] + %5599 = phi float [ 0.000000e+00, %4732 ], [ %11058, %._crit_edge1874 ] + %5600 = phi float [ 0.000000e+00, %4732 ], [ %11059, %._crit_edge1874 ] + %5601 = phi float [ 0.000000e+00, %4732 ], [ %11060, %._crit_edge1874 ] + %5602 = phi float [ 0.000000e+00, %4732 ], [ %11061, %._crit_edge1874 ] + %5603 = phi float [ 0.000000e+00, %4732 ], [ %11062, %._crit_edge1874 ] + %5604 = phi float [ 0.000000e+00, %4732 ], [ %11063, %._crit_edge1874 ] + %5605 = phi float [ 0.000000e+00, %4732 ], [ %11064, %._crit_edge1874 ] + %5606 = phi float [ 0.000000e+00, %4732 ], [ %11065, %._crit_edge1874 ] + %5607 = phi float [ 0.000000e+00, %4732 ], [ %11066, %._crit_edge1874 ] + %5608 = phi float [ 0.000000e+00, %4732 ], [ %11067, %._crit_edge1874 ] + %5609 = phi float [ 0.000000e+00, %4732 ], [ %11068, %._crit_edge1874 ] + %5610 = phi float [ 0.000000e+00, %4732 ], [ %11069, %._crit_edge1874 ] + %5611 = phi float [ 0.000000e+00, %4732 ], [ %11070, %._crit_edge1874 ] + %5612 = phi float [ 0.000000e+00, %4732 ], [ %11071, %._crit_edge1874 ] + %5613 = phi float [ 0.000000e+00, %4732 ], [ %11072, %._crit_edge1874 ] + %5614 = phi float [ 0.000000e+00, %4732 ], [ %11073, %._crit_edge1874 ] + %5615 = phi float [ 0.000000e+00, %4732 ], [ %11074, %._crit_edge1874 ] + %5616 = phi float [ 0.000000e+00, %4732 ], [ %11075, %._crit_edge1874 ] + %5617 = phi float [ 0.000000e+00, %4732 ], [ %11076, %._crit_edge1874 ] + %5618 = phi float [ 0.000000e+00, %4732 ], [ %11077, %._crit_edge1874 ] + %5619 = phi float [ 0.000000e+00, %4732 ], [ %11078, %._crit_edge1874 ] + %5620 = phi float [ 0.000000e+00, %4732 ], [ %11079, %._crit_edge1874 ] + %5621 = phi float [ 0.000000e+00, %4732 ], [ %11080, %._crit_edge1874 ] + %5622 = phi float [ 0.000000e+00, %4732 ], [ %11081, %._crit_edge1874 ] + %5623 = phi float [ 0.000000e+00, %4732 ], [ %11082, %._crit_edge1874 ] + %5624 = phi float [ 0.000000e+00, %4732 ], [ %11083, %._crit_edge1874 ] + %5625 = phi float [ 0.000000e+00, %4732 ], [ %11084, %._crit_edge1874 ] + %5626 = phi float [ 0.000000e+00, %4732 ], [ %11085, %._crit_edge1874 ] + %5627 = phi float [ 0.000000e+00, %4732 ], [ %11086, %._crit_edge1874 ] + %5628 = phi float [ 0.000000e+00, %4732 ], [ %11087, %._crit_edge1874 ] + %5629 = phi float [ 0.000000e+00, %4732 ], [ %11088, %._crit_edge1874 ] + %5630 = phi float [ 0.000000e+00, %4732 ], [ %11089, %._crit_edge1874 ] + %5631 = phi float [ 0.000000e+00, %4732 ], [ %11090, %._crit_edge1874 ] + %5632 = phi float [ 0.000000e+00, %4732 ], [ %11091, %._crit_edge1874 ] + %5633 = phi float [ 0.000000e+00, %4732 ], [ %11092, %._crit_edge1874 ] + %5634 = phi float [ 0.000000e+00, %4732 ], [ %11093, %._crit_edge1874 ] + %5635 = phi float [ 0.000000e+00, %4732 ], [ %11094, %._crit_edge1874 ] + %5636 = phi float [ 0.000000e+00, %4732 ], [ %11095, %._crit_edge1874 ] + %5637 = phi float [ 0.000000e+00, %4732 ], [ %11096, %._crit_edge1874 ] + %5638 = phi float [ 0.000000e+00, %4732 ], [ %11097, %._crit_edge1874 ] + %5639 = phi float [ 0.000000e+00, %4732 ], [ %11098, %._crit_edge1874 ] + %5640 = phi float [ 0.000000e+00, %4732 ], [ %10971, %._crit_edge1874 ] + %5641 = phi float [ 0.000000e+00, %4732 ], [ %10972, %._crit_edge1874 ] + %5642 = phi float [ 0.000000e+00, %4732 ], [ %10973, %._crit_edge1874 ] + %5643 = phi float [ 0.000000e+00, %4732 ], [ %10974, %._crit_edge1874 ] + %5644 = phi float [ 0.000000e+00, %4732 ], [ %10975, %._crit_edge1874 ] + %5645 = phi float [ 0.000000e+00, %4732 ], [ %10976, %._crit_edge1874 ] + %5646 = phi float [ 0.000000e+00, %4732 ], [ %10977, %._crit_edge1874 ] + %5647 = phi float [ 0.000000e+00, %4732 ], [ %10978, %._crit_edge1874 ] + %5648 = phi float [ 0.000000e+00, %4732 ], [ %10979, %._crit_edge1874 ] + %5649 = phi float [ 0.000000e+00, %4732 ], [ %10980, %._crit_edge1874 ] + %5650 = phi float [ 0.000000e+00, %4732 ], [ %10981, %._crit_edge1874 ] + %5651 = phi float [ 0.000000e+00, %4732 ], [ %10982, %._crit_edge1874 ] + %5652 = phi float [ 0.000000e+00, %4732 ], [ %10983, %._crit_edge1874 ] + %5653 = phi float [ 0.000000e+00, %4732 ], [ %10984, %._crit_edge1874 ] + %5654 = phi float [ 0.000000e+00, %4732 ], [ %10985, %._crit_edge1874 ] + %5655 = phi float [ 0.000000e+00, %4732 ], [ %10986, %._crit_edge1874 ] + %5656 = phi float [ 0.000000e+00, %4732 ], [ %10987, %._crit_edge1874 ] + %5657 = phi float [ 0.000000e+00, %4732 ], [ %10988, %._crit_edge1874 ] + %5658 = phi float [ 0.000000e+00, %4732 ], [ %10989, %._crit_edge1874 ] + %5659 = phi float [ 0.000000e+00, %4732 ], [ %10990, %._crit_edge1874 ] + %5660 = phi float [ 0.000000e+00, %4732 ], [ %10991, %._crit_edge1874 ] + %5661 = phi float [ 0.000000e+00, %4732 ], [ %10992, %._crit_edge1874 ] + %5662 = phi float [ 0.000000e+00, %4732 ], [ %10993, %._crit_edge1874 ] + %5663 = phi float [ 0.000000e+00, %4732 ], [ %10994, %._crit_edge1874 ] + %5664 = phi float [ 0.000000e+00, %4732 ], [ %10995, %._crit_edge1874 ] + %5665 = phi float [ 0.000000e+00, %4732 ], [ %10996, %._crit_edge1874 ] + %5666 = phi float [ 0.000000e+00, %4732 ], [ %10997, %._crit_edge1874 ] + %5667 = phi float [ 0.000000e+00, %4732 ], [ %10998, %._crit_edge1874 ] + %5668 = phi float [ 0.000000e+00, %4732 ], [ %10999, %._crit_edge1874 ] + %5669 = phi float [ 0.000000e+00, %4732 ], [ %11000, %._crit_edge1874 ] + %5670 = phi float [ 0.000000e+00, %4732 ], [ %11001, %._crit_edge1874 ] + %5671 = phi float [ 0.000000e+00, %4732 ], [ %11002, %._crit_edge1874 ] + %5672 = phi float [ 0.000000e+00, %4732 ], [ %11003, %._crit_edge1874 ] + %5673 = phi float [ 0.000000e+00, %4732 ], [ %11004, %._crit_edge1874 ] + %5674 = phi float [ 0.000000e+00, %4732 ], [ %11005, %._crit_edge1874 ] + %5675 = phi float [ 0.000000e+00, %4732 ], [ %11006, %._crit_edge1874 ] + %5676 = phi float [ 0.000000e+00, %4732 ], [ %11007, %._crit_edge1874 ] + %5677 = phi float [ 0.000000e+00, %4732 ], [ %11008, %._crit_edge1874 ] + %5678 = phi float [ 0.000000e+00, %4732 ], [ %11009, %._crit_edge1874 ] + %5679 = phi float [ 0.000000e+00, %4732 ], [ %11010, %._crit_edge1874 ] + %5680 = phi float [ 0.000000e+00, %4732 ], [ %11011, %._crit_edge1874 ] + %5681 = phi float [ 0.000000e+00, %4732 ], [ %11012, %._crit_edge1874 ] + %5682 = phi float [ 0.000000e+00, %4732 ], [ %11013, %._crit_edge1874 ] + %5683 = phi float [ 0.000000e+00, %4732 ], [ %11014, %._crit_edge1874 ] + %5684 = phi float [ 0.000000e+00, %4732 ], [ %11015, %._crit_edge1874 ] + %5685 = phi float [ 0.000000e+00, %4732 ], [ %11016, %._crit_edge1874 ] + %5686 = phi float [ 0.000000e+00, %4732 ], [ %11017, %._crit_edge1874 ] + %5687 = phi float [ 0.000000e+00, %4732 ], [ %11018, %._crit_edge1874 ] + %5688 = phi float [ 0.000000e+00, %4732 ], [ %11019, %._crit_edge1874 ] + %5689 = phi float [ 0.000000e+00, %4732 ], [ %11020, %._crit_edge1874 ] + %5690 = phi float [ 0.000000e+00, %4732 ], [ %11021, %._crit_edge1874 ] + %5691 = phi float [ 0.000000e+00, %4732 ], [ %11022, %._crit_edge1874 ] + %5692 = phi float [ 0.000000e+00, %4732 ], [ %11023, %._crit_edge1874 ] + %5693 = phi float [ 0.000000e+00, %4732 ], [ %11024, %._crit_edge1874 ] + %5694 = phi float [ 0.000000e+00, %4732 ], [ %11025, %._crit_edge1874 ] + %5695 = phi float [ 0.000000e+00, %4732 ], [ %11026, %._crit_edge1874 ] + %5696 = phi float [ 0.000000e+00, %4732 ], [ %11027, %._crit_edge1874 ] + %5697 = phi float [ 0.000000e+00, %4732 ], [ %11028, %._crit_edge1874 ] + %5698 = phi float [ 0.000000e+00, %4732 ], [ %11029, %._crit_edge1874 ] + %5699 = phi float [ 0.000000e+00, %4732 ], [ %11030, %._crit_edge1874 ] + %5700 = phi float [ 0.000000e+00, %4732 ], [ %11031, %._crit_edge1874 ] + %5701 = phi float [ 0.000000e+00, %4732 ], [ %11032, %._crit_edge1874 ] + %5702 = phi float [ 0.000000e+00, %4732 ], [ %11033, %._crit_edge1874 ] + %5703 = phi float [ 0.000000e+00, %4732 ], [ %11034, %._crit_edge1874 ] + %5704 = add nuw nsw i64 %indvars.iv, %5570, !dbg !250 + %.tr = trunc i64 %5704 to i32, !dbg !251 + %5705 = shl i32 %.tr, 7, !dbg !251 + %5706 = add i32 %5705, %4978, !dbg !251 + %5707 = sext i32 %5706 to i64, !dbg !252 + %5708 = trunc nuw nsw i64 %5704 to i32, !dbg !253 + %5709 = mul i32 %29, %5708, !dbg !253 + %5710 = add i32 %5709, %4979, !dbg !254 + %5711 = sext i32 %5710 to i64, !dbg !255 + %5712 = trunc i64 %5704 to i32, !dbg !256 + %5713 = add i32 %4980, %5712, !dbg !256 + %5714 = mul i32 %5713, %17, !dbg !256 + %5715 = sext i32 %5714 to i64, !dbg !257 + %5716 = getelementptr bfloat, ptr addrspace(1) %0, i64 %5707, !dbg !258 + %5717 = getelementptr bfloat, ptr addrspace(1) %5, i64 %5711, !dbg !259 + %5718 = getelementptr float, ptr addrspace(1) %3, i64 %5715, !dbg !260 + %5719 = getelementptr float, ptr addrspace(1) %4, i64 %5715, !dbg !261 + %5720 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5083, !dbg !262 + %5721 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5084, !dbg !262 + %5722 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5085, !dbg !262 + %5723 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5086, !dbg !262 + %5724 = getelementptr bfloat, ptr addrspace(1) %5720, i64 %4768, !dbg !263 + %5725 = getelementptr bfloat, ptr addrspace(1) %5721, i64 %4768, !dbg !263 + %5726 = getelementptr bfloat, ptr addrspace(1) %5722, i64 %4768, !dbg !263 + %5727 = getelementptr bfloat, ptr addrspace(1) %5723, i64 %4768, !dbg !263 + %5728 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5087, !dbg !264 + %5729 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5088, !dbg !264 + %5730 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5089, !dbg !264 + %5731 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5090, !dbg !264 + %5732 = getelementptr bfloat, ptr addrspace(1) %5728, i64 %4768, !dbg !265 + %5733 = getelementptr bfloat, ptr addrspace(1) %5729, i64 %4768, !dbg !265 + %5734 = getelementptr bfloat, ptr addrspace(1) %5730, i64 %4768, !dbg !265 + %5735 = getelementptr bfloat, ptr addrspace(1) %5731, i64 %4768, !dbg !265 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5102, ptr addrspace(1) %5724, i32 %5103) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5105, ptr addrspace(1) %5725, i32 %5106) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5108, ptr addrspace(1) %5726, i32 %5109) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5111, ptr addrspace(1) %5727, i32 %5112) #3, !dbg !266 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !266 + %5736 = getelementptr float, ptr addrspace(1) %5718, i64 %5145, !dbg !267 + %5737 = getelementptr float, ptr addrspace(1) %5718, i64 %5146, !dbg !267 + %5738 = getelementptr float, ptr addrspace(1) %5718, i64 %5147, !dbg !267 + %5739 = getelementptr float, ptr addrspace(1) %5718, i64 %5148, !dbg !267 + %5740 = getelementptr float, ptr addrspace(1) %5718, i64 %5149, !dbg !267 + %5741 = getelementptr float, ptr addrspace(1) %5718, i64 %5150, !dbg !267 + %5742 = getelementptr float, ptr addrspace(1) %5718, i64 %5151, !dbg !267 + %5743 = getelementptr float, ptr addrspace(1) %5718, i64 %5152, !dbg !267 + %5744 = getelementptr float, ptr addrspace(1) %5718, i64 %5153, !dbg !267 + %5745 = getelementptr float, ptr addrspace(1) %5718, i64 %5154, !dbg !267 + %5746 = getelementptr float, ptr addrspace(1) %5718, i64 %5155, !dbg !267 + %5747 = getelementptr float, ptr addrspace(1) %5718, i64 %5156, !dbg !267 + %5748 = getelementptr float, ptr addrspace(1) %5718, i64 %5157, !dbg !267 + %5749 = getelementptr float, ptr addrspace(1) %5718, i64 %5158, !dbg !267 + %5750 = getelementptr float, ptr addrspace(1) %5718, i64 %5159, !dbg !267 + %5751 = getelementptr float, ptr addrspace(1) %5718, i64 %5160, !dbg !267 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5180, ptr addrspace(1) %5736, i32 %5181, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5183, ptr addrspace(1) %5737, i32 %5184, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5186, ptr addrspace(1) %5738, i32 %5187, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5189, ptr addrspace(1) %5739, i32 %5190, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5192, ptr addrspace(1) %5740, i32 %5193, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5195, ptr addrspace(1) %5741, i32 %5196, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5198, ptr addrspace(1) %5742, i32 %5199, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5201, ptr addrspace(1) %5743, i32 %5202, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5204, ptr addrspace(1) %5744, i32 %5205, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5207, ptr addrspace(1) %5745, i32 %5208, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5210, ptr addrspace(1) %5746, i32 %5211, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5213, ptr addrspace(1) %5747, i32 %5214, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5216, ptr addrspace(1) %5748, i32 %5217, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5219, ptr addrspace(1) %5749, i32 %5220, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5222, ptr addrspace(1) %5750, i32 %5223, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5225, ptr addrspace(1) %5751, i32 %5226, i1 %5178) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5227, ptr addrspace(1) %5732, i32 %5103) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5228, ptr addrspace(1) %5733, i32 %5106) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5229, ptr addrspace(1) %5734, i32 %5109) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5230, ptr addrspace(1) %5735, i32 %5112) #3, !dbg !269 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !269 + %5752 = getelementptr float, ptr addrspace(1) %5719, i64 %5145, !dbg !270 + %5753 = getelementptr float, ptr addrspace(1) %5719, i64 %5146, !dbg !270 + %5754 = getelementptr float, ptr addrspace(1) %5719, i64 %5147, !dbg !270 + %5755 = getelementptr float, ptr addrspace(1) %5719, i64 %5148, !dbg !270 + %5756 = getelementptr float, ptr addrspace(1) %5719, i64 %5149, !dbg !270 + %5757 = getelementptr float, ptr addrspace(1) %5719, i64 %5150, !dbg !270 + %5758 = getelementptr float, ptr addrspace(1) %5719, i64 %5151, !dbg !270 + %5759 = getelementptr float, ptr addrspace(1) %5719, i64 %5152, !dbg !270 + %5760 = getelementptr float, ptr addrspace(1) %5719, i64 %5153, !dbg !270 + %5761 = getelementptr float, ptr addrspace(1) %5719, i64 %5154, !dbg !270 + %5762 = getelementptr float, ptr addrspace(1) %5719, i64 %5155, !dbg !270 + %5763 = getelementptr float, ptr addrspace(1) %5719, i64 %5156, !dbg !270 + %5764 = getelementptr float, ptr addrspace(1) %5719, i64 %5157, !dbg !270 + %5765 = getelementptr float, ptr addrspace(1) %5719, i64 %5158, !dbg !270 + %5766 = getelementptr float, ptr addrspace(1) %5719, i64 %5159, !dbg !270 + %5767 = getelementptr float, ptr addrspace(1) %5719, i64 %5160, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5231, ptr addrspace(1) %5752, i32 %5181, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5232, ptr addrspace(1) %5753, i32 %5184, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5233, ptr addrspace(1) %5754, i32 %5187, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5234, ptr addrspace(1) %5755, i32 %5190, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5235, ptr addrspace(1) %5756, i32 %5193, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5236, ptr addrspace(1) %5757, i32 %5196, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5237, ptr addrspace(1) %5758, i32 %5199, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5238, ptr addrspace(1) %5759, i32 %5202, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5239, ptr addrspace(1) %5760, i32 %5205, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5240, ptr addrspace(1) %5761, i32 %5208, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5241, ptr addrspace(1) %5762, i32 %5211, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5242, ptr addrspace(1) %5763, i32 %5214, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5243, ptr addrspace(1) %5764, i32 %5217, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5244, ptr addrspace(1) %5765, i32 %5220, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5245, ptr addrspace(1) %5766, i32 %5223, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5246, ptr addrspace(1) %5767, i32 %5226, i1 %5178) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + %5768 = getelementptr i8, ptr addrspace(1) %5724, i64 524288, !dbg !272 + %5769 = getelementptr i8, ptr addrspace(1) %5725, i64 524288, !dbg !272 + %5770 = getelementptr i8, ptr addrspace(1) %5726, i64 524288, !dbg !272 + %5771 = getelementptr i8, ptr addrspace(1) %5727, i64 524288, !dbg !272 + %5772 = getelementptr i8, ptr addrspace(1) %5732, i64 16384, !dbg !273 + %5773 = getelementptr i8, ptr addrspace(1) %5733, i64 16384, !dbg !273 + %5774 = getelementptr i8, ptr addrspace(1) %5734, i64 16384, !dbg !273 + %5775 = getelementptr i8, ptr addrspace(1) %5735, i64 16384, !dbg !273 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5276, ptr addrspace(1) %5768, i32 %5277) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5278, ptr addrspace(1) %5769, i32 %5279) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5280, ptr addrspace(1) %5770, i32 %5281) #3, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5282, ptr addrspace(1) %5771, i32 %5283) #3, !dbg !266 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !266 + %5776 = getelementptr float, ptr addrspace(1) %5718, i64 %5300, !dbg !267 + %5777 = getelementptr float, ptr addrspace(1) %5718, i64 %5301, !dbg !267 + %5778 = getelementptr float, ptr addrspace(1) %5718, i64 %5302, !dbg !267 + %5779 = getelementptr float, ptr addrspace(1) %5718, i64 %5303, !dbg !267 + %5780 = getelementptr float, ptr addrspace(1) %5718, i64 %5304, !dbg !267 + %5781 = getelementptr float, ptr addrspace(1) %5718, i64 %5305, !dbg !267 + %5782 = getelementptr float, ptr addrspace(1) %5718, i64 %5306, !dbg !267 + %5783 = getelementptr float, ptr addrspace(1) %5718, i64 %5307, !dbg !267 + %5784 = getelementptr float, ptr addrspace(1) %5718, i64 %5308, !dbg !267 + %5785 = getelementptr float, ptr addrspace(1) %5718, i64 %5309, !dbg !267 + %5786 = getelementptr float, ptr addrspace(1) %5718, i64 %5310, !dbg !267 + %5787 = getelementptr float, ptr addrspace(1) %5718, i64 %5311, !dbg !267 + %5788 = getelementptr float, ptr addrspace(1) %5718, i64 %5312, !dbg !267 + %5789 = getelementptr float, ptr addrspace(1) %5718, i64 %5313, !dbg !267 + %5790 = getelementptr float, ptr addrspace(1) %5718, i64 %5314, !dbg !267 + %5791 = getelementptr float, ptr addrspace(1) %5718, i64 %5315, !dbg !267 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5332, ptr addrspace(1) %5776, i32 %5333, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5334, ptr addrspace(1) %5777, i32 %5335, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5336, ptr addrspace(1) %5778, i32 %5337, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5338, ptr addrspace(1) %5779, i32 %5339, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5340, ptr addrspace(1) %5780, i32 %5341, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5342, ptr addrspace(1) %5781, i32 %5343, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5344, ptr addrspace(1) %5782, i32 %5345, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5346, ptr addrspace(1) %5783, i32 %5347, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5348, ptr addrspace(1) %5784, i32 %5349, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5350, ptr addrspace(1) %5785, i32 %5351, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5352, ptr addrspace(1) %5786, i32 %5353, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5354, ptr addrspace(1) %5787, i32 %5355, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5356, ptr addrspace(1) %5788, i32 %5357, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5358, ptr addrspace(1) %5789, i32 %5359, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5360, ptr addrspace(1) %5790, i32 %5361, i1 %5178) #3, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5362, ptr addrspace(1) %5791, i32 %5363, i1 %5178) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5364, ptr addrspace(1) %5772, i32 %5277) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5365, ptr addrspace(1) %5773, i32 %5279) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5366, ptr addrspace(1) %5774, i32 %5281) #3, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5367, ptr addrspace(1) %5775, i32 %5283) #3, !dbg !269 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !269 + %5792 = getelementptr float, ptr addrspace(1) %5719, i64 %5300, !dbg !270 + %5793 = getelementptr float, ptr addrspace(1) %5719, i64 %5301, !dbg !270 + %5794 = getelementptr float, ptr addrspace(1) %5719, i64 %5302, !dbg !270 + %5795 = getelementptr float, ptr addrspace(1) %5719, i64 %5303, !dbg !270 + %5796 = getelementptr float, ptr addrspace(1) %5719, i64 %5304, !dbg !270 + %5797 = getelementptr float, ptr addrspace(1) %5719, i64 %5305, !dbg !270 + %5798 = getelementptr float, ptr addrspace(1) %5719, i64 %5306, !dbg !270 + %5799 = getelementptr float, ptr addrspace(1) %5719, i64 %5307, !dbg !270 + %5800 = getelementptr float, ptr addrspace(1) %5719, i64 %5308, !dbg !270 + %5801 = getelementptr float, ptr addrspace(1) %5719, i64 %5309, !dbg !270 + %5802 = getelementptr float, ptr addrspace(1) %5719, i64 %5310, !dbg !270 + %5803 = getelementptr float, ptr addrspace(1) %5719, i64 %5311, !dbg !270 + %5804 = getelementptr float, ptr addrspace(1) %5719, i64 %5312, !dbg !270 + %5805 = getelementptr float, ptr addrspace(1) %5719, i64 %5313, !dbg !270 + %5806 = getelementptr float, ptr addrspace(1) %5719, i64 %5314, !dbg !270 + %5807 = getelementptr float, ptr addrspace(1) %5719, i64 %5315, !dbg !270 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5368, ptr addrspace(1) %5792, i32 %5333, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5369, ptr addrspace(1) %5793, i32 %5335, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5370, ptr addrspace(1) %5794, i32 %5337, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5371, ptr addrspace(1) %5795, i32 %5339, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5372, ptr addrspace(1) %5796, i32 %5341, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5373, ptr addrspace(1) %5797, i32 %5343, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5374, ptr addrspace(1) %5798, i32 %5345, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5375, ptr addrspace(1) %5799, i32 %5347, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5376, ptr addrspace(1) %5800, i32 %5349, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5377, ptr addrspace(1) %5801, i32 %5351, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5378, ptr addrspace(1) %5802, i32 %5353, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5379, ptr addrspace(1) %5803, i32 %5355, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5380, ptr addrspace(1) %5804, i32 %5357, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5381, ptr addrspace(1) %5805, i32 %5359, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5382, ptr addrspace(1) %5806, i32 %5361, i1 %5178) #3, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5383, ptr addrspace(1) %5807, i32 %5363, i1 %5178) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + br i1 %5091, label %.lr.ph1700, label %._crit_edge1701, !dbg !274 + +.lr.ph1700: ; preds = %5575, %__nv_exp2f.exit1417 + %5808 = phi i32 [ %8207, %__nv_exp2f.exit1417 ], [ 64, %5575 ] + %5809 = phi i32 [ %.pn2191683, %__nv_exp2f.exit1417 ], [ %5113, %5575 ] + %5810 = phi i32 [ %.pn2171684, %__nv_exp2f.exit1417 ], [ %5115, %5575 ] + %5811 = phi i32 [ %.pn2151685, %__nv_exp2f.exit1417 ], [ %5117, %5575 ] + %5812 = phi i32 [ %.pn2131686, %__nv_exp2f.exit1417 ], [ %5119, %5575 ] + %5813 = phi i32 [ %.pn2111687, %__nv_exp2f.exit1417 ], [ %5121, %5575 ] + %5814 = phi i32 [ %.pn2091688, %__nv_exp2f.exit1417 ], [ %5123, %5575 ] + %5815 = phi i32 [ %.pn2071689, %__nv_exp2f.exit1417 ], [ %5125, %5575 ] + %5816 = phi i32 [ %.pn2051690, %__nv_exp2f.exit1417 ], [ %5127, %5575 ] + %5817 = phi i32 [ %.pn2031691, %__nv_exp2f.exit1417 ], [ %5129, %5575 ] + %5818 = phi i32 [ %.pn2011692, %__nv_exp2f.exit1417 ], [ %5131, %5575 ] + %5819 = phi i32 [ %.pn1991693, %__nv_exp2f.exit1417 ], [ %5133, %5575 ] + %5820 = phi i32 [ %.pn1971694, %__nv_exp2f.exit1417 ], [ %5135, %5575 ] + %5821 = phi i32 [ %.pn1951695, %__nv_exp2f.exit1417 ], [ %5137, %5575 ] + %5822 = phi i32 [ %.pn1931696, %__nv_exp2f.exit1417 ], [ %5139, %5575 ] + %5823 = phi i32 [ %.pn1911697, %__nv_exp2f.exit1417 ], [ %5141, %5575 ] + %5824 = phi i32 [ %.pn1891698, %__nv_exp2f.exit1417 ], [ %5143, %5575 ] + %5825 = phi i32 [ %5971, %__nv_exp2f.exit1417 ], [ -1, %5575 ] + %5826 = phi i32 [ %8246, %__nv_exp2f.exit1417 ], [ 1, %5575 ] + %5827 = phi i32 [ %5974, %__nv_exp2f.exit1417 ], [ -1, %5575 ] + %5828 = phi i32 [ %8249, %__nv_exp2f.exit1417 ], [ 1, %5575 ] + %.pn1891698 = phi i32 [ %8235, %__nv_exp2f.exit1417 ], [ %5263, %5575 ] + %.pn1911697 = phi i32 [ %8234, %__nv_exp2f.exit1417 ], [ %5262, %5575 ] + %.pn1931696 = phi i32 [ %8233, %__nv_exp2f.exit1417 ], [ %5261, %5575 ] + %.pn1951695 = phi i32 [ %8232, %__nv_exp2f.exit1417 ], [ %5260, %5575 ] + %.pn1971694 = phi i32 [ %8231, %__nv_exp2f.exit1417 ], [ %5259, %5575 ] + %.pn1991693 = phi i32 [ %8230, %__nv_exp2f.exit1417 ], [ %5258, %5575 ] + %.pn2011692 = phi i32 [ %8229, %__nv_exp2f.exit1417 ], [ %5257, %5575 ] + %.pn2031691 = phi i32 [ %8228, %__nv_exp2f.exit1417 ], [ %5256, %5575 ] + %.pn2051690 = phi i32 [ %8227, %__nv_exp2f.exit1417 ], [ %5255, %5575 ] + %.pn2071689 = phi i32 [ %8226, %__nv_exp2f.exit1417 ], [ %5254, %5575 ] + %.pn2091688 = phi i32 [ %8225, %__nv_exp2f.exit1417 ], [ %5253, %5575 ] + %.pn2111687 = phi i32 [ %8224, %__nv_exp2f.exit1417 ], [ %5252, %5575 ] + %.pn2131686 = phi i32 [ %8223, %__nv_exp2f.exit1417 ], [ %5251, %5575 ] + %.pn2151685 = phi i32 [ %8222, %__nv_exp2f.exit1417 ], [ %5250, %5575 ] + %.pn2171684 = phi i32 [ %8221, %__nv_exp2f.exit1417 ], [ %5249, %5575 ] + %.pn2191683 = phi i32 [ %8220, %__nv_exp2f.exit1417 ], [ %5248, %5575 ] + %5829 = phi i32 [ %8240, %__nv_exp2f.exit1417 ], [ %5264, %5575 ] + %5830 = phi i32 [ %8241, %__nv_exp2f.exit1417 ], [ %5265, %5575 ] + %5831 = phi i32 [ %8242, %__nv_exp2f.exit1417 ], [ %5266, %5575 ] + %5832 = phi i32 [ %8243, %__nv_exp2f.exit1417 ], [ %5267, %5575 ] + %.pn1391682 = phi ptr addrspace(1) [ %8219, %__nv_exp2f.exit1417 ], [ %5775, %5575 ] + %.pn1551681 = phi ptr addrspace(1) [ %8218, %__nv_exp2f.exit1417 ], [ %5774, %5575 ] + %.pn1711680 = phi ptr addrspace(1) [ %8217, %__nv_exp2f.exit1417 ], [ %5773, %5575 ] + %.pn1871679 = phi ptr addrspace(1) [ %8216, %__nv_exp2f.exit1417 ], [ %5772, %5575 ] + %5833 = phi i32 [ %8236, %__nv_exp2f.exit1417 ], [ %5264, %5575 ] + %5834 = phi i32 [ %8237, %__nv_exp2f.exit1417 ], [ %5265, %5575 ] + %5835 = phi i32 [ %8238, %__nv_exp2f.exit1417 ], [ %5266, %5575 ] + %5836 = phi i32 [ %8239, %__nv_exp2f.exit1417 ], [ %5267, %5575 ] + %.pn751678 = phi ptr addrspace(1) [ %8213, %__nv_exp2f.exit1417 ], [ %5771, %5575 ] + %.pn911677 = phi ptr addrspace(1) [ %8212, %__nv_exp2f.exit1417 ], [ %5770, %5575 ] + %.pn1071676 = phi ptr addrspace(1) [ %8211, %__nv_exp2f.exit1417 ], [ %5769, %5575 ] + %.pn1231675 = phi ptr addrspace(1) [ %8210, %__nv_exp2f.exit1417 ], [ %5768, %5575 ] + %5837 = phi float [ %7265, %__nv_exp2f.exit1417 ], [ %5640, %5575 ] + %5838 = phi float [ %7266, %__nv_exp2f.exit1417 ], [ %5641, %5575 ] + %5839 = phi float [ %7267, %__nv_exp2f.exit1417 ], [ %5642, %5575 ] + %5840 = phi float [ %7268, %__nv_exp2f.exit1417 ], [ %5643, %5575 ] + %5841 = phi float [ %7269, %__nv_exp2f.exit1417 ], [ %5644, %5575 ] + %5842 = phi float [ %7270, %__nv_exp2f.exit1417 ], [ %5645, %5575 ] + %5843 = phi float [ %7271, %__nv_exp2f.exit1417 ], [ %5646, %5575 ] + %5844 = phi float [ %7272, %__nv_exp2f.exit1417 ], [ %5647, %5575 ] + %5845 = phi float [ %7273, %__nv_exp2f.exit1417 ], [ %5648, %5575 ] + %5846 = phi float [ %7274, %__nv_exp2f.exit1417 ], [ %5649, %5575 ] + %5847 = phi float [ %7275, %__nv_exp2f.exit1417 ], [ %5650, %5575 ] + %5848 = phi float [ %7276, %__nv_exp2f.exit1417 ], [ %5651, %5575 ] + %5849 = phi float [ %7277, %__nv_exp2f.exit1417 ], [ %5652, %5575 ] + %5850 = phi float [ %7278, %__nv_exp2f.exit1417 ], [ %5653, %5575 ] + %5851 = phi float [ %7279, %__nv_exp2f.exit1417 ], [ %5654, %5575 ] + %5852 = phi float [ %7280, %__nv_exp2f.exit1417 ], [ %5655, %5575 ] + %5853 = phi float [ %7281, %__nv_exp2f.exit1417 ], [ %5656, %5575 ] + %5854 = phi float [ %7282, %__nv_exp2f.exit1417 ], [ %5657, %5575 ] + %5855 = phi float [ %7283, %__nv_exp2f.exit1417 ], [ %5658, %5575 ] + %5856 = phi float [ %7284, %__nv_exp2f.exit1417 ], [ %5659, %5575 ] + %5857 = phi float [ %7285, %__nv_exp2f.exit1417 ], [ %5660, %5575 ] + %5858 = phi float [ %7286, %__nv_exp2f.exit1417 ], [ %5661, %5575 ] + %5859 = phi float [ %7287, %__nv_exp2f.exit1417 ], [ %5662, %5575 ] + %5860 = phi float [ %7288, %__nv_exp2f.exit1417 ], [ %5663, %5575 ] + %5861 = phi float [ %7289, %__nv_exp2f.exit1417 ], [ %5664, %5575 ] + %5862 = phi float [ %7290, %__nv_exp2f.exit1417 ], [ %5665, %5575 ] + %5863 = phi float [ %7291, %__nv_exp2f.exit1417 ], [ %5666, %5575 ] + %5864 = phi float [ %7292, %__nv_exp2f.exit1417 ], [ %5667, %5575 ] + %5865 = phi float [ %7293, %__nv_exp2f.exit1417 ], [ %5668, %5575 ] + %5866 = phi float [ %7294, %__nv_exp2f.exit1417 ], [ %5669, %5575 ] + %5867 = phi float [ %7295, %__nv_exp2f.exit1417 ], [ %5670, %5575 ] + %5868 = phi float [ %7296, %__nv_exp2f.exit1417 ], [ %5671, %5575 ] + %5869 = phi float [ %7297, %__nv_exp2f.exit1417 ], [ %5672, %5575 ] + %5870 = phi float [ %7298, %__nv_exp2f.exit1417 ], [ %5673, %5575 ] + %5871 = phi float [ %7299, %__nv_exp2f.exit1417 ], [ %5674, %5575 ] + %5872 = phi float [ %7300, %__nv_exp2f.exit1417 ], [ %5675, %5575 ] + %5873 = phi float [ %7301, %__nv_exp2f.exit1417 ], [ %5676, %5575 ] + %5874 = phi float [ %7302, %__nv_exp2f.exit1417 ], [ %5677, %5575 ] + %5875 = phi float [ %7303, %__nv_exp2f.exit1417 ], [ %5678, %5575 ] + %5876 = phi float [ %7304, %__nv_exp2f.exit1417 ], [ %5679, %5575 ] + %5877 = phi float [ %7305, %__nv_exp2f.exit1417 ], [ %5680, %5575 ] + %5878 = phi float [ %7306, %__nv_exp2f.exit1417 ], [ %5681, %5575 ] + %5879 = phi float [ %7307, %__nv_exp2f.exit1417 ], [ %5682, %5575 ] + %5880 = phi float [ %7308, %__nv_exp2f.exit1417 ], [ %5683, %5575 ] + %5881 = phi float [ %7309, %__nv_exp2f.exit1417 ], [ %5684, %5575 ] + %5882 = phi float [ %7310, %__nv_exp2f.exit1417 ], [ %5685, %5575 ] + %5883 = phi float [ %7311, %__nv_exp2f.exit1417 ], [ %5686, %5575 ] + %5884 = phi float [ %7312, %__nv_exp2f.exit1417 ], [ %5687, %5575 ] + %5885 = phi float [ %7313, %__nv_exp2f.exit1417 ], [ %5688, %5575 ] + %5886 = phi float [ %7314, %__nv_exp2f.exit1417 ], [ %5689, %5575 ] + %5887 = phi float [ %7315, %__nv_exp2f.exit1417 ], [ %5690, %5575 ] + %5888 = phi float [ %7316, %__nv_exp2f.exit1417 ], [ %5691, %5575 ] + %5889 = phi float [ %7317, %__nv_exp2f.exit1417 ], [ %5692, %5575 ] + %5890 = phi float [ %7318, %__nv_exp2f.exit1417 ], [ %5693, %5575 ] + %5891 = phi float [ %7319, %__nv_exp2f.exit1417 ], [ %5694, %5575 ] + %5892 = phi float [ %7320, %__nv_exp2f.exit1417 ], [ %5695, %5575 ] + %5893 = phi float [ %7321, %__nv_exp2f.exit1417 ], [ %5696, %5575 ] + %5894 = phi float [ %7322, %__nv_exp2f.exit1417 ], [ %5697, %5575 ] + %5895 = phi float [ %7323, %__nv_exp2f.exit1417 ], [ %5698, %5575 ] + %5896 = phi float [ %7324, %__nv_exp2f.exit1417 ], [ %5699, %5575 ] + %5897 = phi float [ %7325, %__nv_exp2f.exit1417 ], [ %5700, %5575 ] + %5898 = phi float [ %7326, %__nv_exp2f.exit1417 ], [ %5701, %5575 ] + %5899 = phi float [ %7327, %__nv_exp2f.exit1417 ], [ %5702, %5575 ] + %5900 = phi float [ %7328, %__nv_exp2f.exit1417 ], [ %5703, %5575 ] + %5901 = phi float [ %8121, %__nv_exp2f.exit1417 ], [ %5576, %5575 ] + %5902 = phi float [ %8122, %__nv_exp2f.exit1417 ], [ %5577, %5575 ] + %5903 = phi float [ %8123, %__nv_exp2f.exit1417 ], [ %5578, %5575 ] + %5904 = phi float [ %8124, %__nv_exp2f.exit1417 ], [ %5579, %5575 ] + %5905 = phi float [ %8125, %__nv_exp2f.exit1417 ], [ %5580, %5575 ] + %5906 = phi float [ %8126, %__nv_exp2f.exit1417 ], [ %5581, %5575 ] + %5907 = phi float [ %8127, %__nv_exp2f.exit1417 ], [ %5582, %5575 ] + %5908 = phi float [ %8128, %__nv_exp2f.exit1417 ], [ %5583, %5575 ] + %5909 = phi float [ %8129, %__nv_exp2f.exit1417 ], [ %5584, %5575 ] + %5910 = phi float [ %8130, %__nv_exp2f.exit1417 ], [ %5585, %5575 ] + %5911 = phi float [ %8131, %__nv_exp2f.exit1417 ], [ %5586, %5575 ] + %5912 = phi float [ %8132, %__nv_exp2f.exit1417 ], [ %5587, %5575 ] + %5913 = phi float [ %8133, %__nv_exp2f.exit1417 ], [ %5588, %5575 ] + %5914 = phi float [ %8134, %__nv_exp2f.exit1417 ], [ %5589, %5575 ] + %5915 = phi float [ %8135, %__nv_exp2f.exit1417 ], [ %5590, %5575 ] + %5916 = phi float [ %8136, %__nv_exp2f.exit1417 ], [ %5591, %5575 ] + %5917 = phi float [ %8137, %__nv_exp2f.exit1417 ], [ %5592, %5575 ] + %5918 = phi float [ %8138, %__nv_exp2f.exit1417 ], [ %5593, %5575 ] + %5919 = phi float [ %8139, %__nv_exp2f.exit1417 ], [ %5594, %5575 ] + %5920 = phi float [ %8140, %__nv_exp2f.exit1417 ], [ %5595, %5575 ] + %5921 = phi float [ %8141, %__nv_exp2f.exit1417 ], [ %5596, %5575 ] + %5922 = phi float [ %8142, %__nv_exp2f.exit1417 ], [ %5597, %5575 ] + %5923 = phi float [ %8143, %__nv_exp2f.exit1417 ], [ %5598, %5575 ] + %5924 = phi float [ %8144, %__nv_exp2f.exit1417 ], [ %5599, %5575 ] + %5925 = phi float [ %8145, %__nv_exp2f.exit1417 ], [ %5600, %5575 ] + %5926 = phi float [ %8146, %__nv_exp2f.exit1417 ], [ %5601, %5575 ] + %5927 = phi float [ %8147, %__nv_exp2f.exit1417 ], [ %5602, %5575 ] + %5928 = phi float [ %8148, %__nv_exp2f.exit1417 ], [ %5603, %5575 ] + %5929 = phi float [ %8149, %__nv_exp2f.exit1417 ], [ %5604, %5575 ] + %5930 = phi float [ %8150, %__nv_exp2f.exit1417 ], [ %5605, %5575 ] + %5931 = phi float [ %8151, %__nv_exp2f.exit1417 ], [ %5606, %5575 ] + %5932 = phi float [ %8152, %__nv_exp2f.exit1417 ], [ %5607, %5575 ] + %5933 = phi float [ %8153, %__nv_exp2f.exit1417 ], [ %5608, %5575 ] + %5934 = phi float [ %8154, %__nv_exp2f.exit1417 ], [ %5609, %5575 ] + %5935 = phi float [ %8155, %__nv_exp2f.exit1417 ], [ %5610, %5575 ] + %5936 = phi float [ %8156, %__nv_exp2f.exit1417 ], [ %5611, %5575 ] + %5937 = phi float [ %8157, %__nv_exp2f.exit1417 ], [ %5612, %5575 ] + %5938 = phi float [ %8158, %__nv_exp2f.exit1417 ], [ %5613, %5575 ] + %5939 = phi float [ %8159, %__nv_exp2f.exit1417 ], [ %5614, %5575 ] + %5940 = phi float [ %8160, %__nv_exp2f.exit1417 ], [ %5615, %5575 ] + %5941 = phi float [ %8161, %__nv_exp2f.exit1417 ], [ %5616, %5575 ] + %5942 = phi float [ %8162, %__nv_exp2f.exit1417 ], [ %5617, %5575 ] + %5943 = phi float [ %8163, %__nv_exp2f.exit1417 ], [ %5618, %5575 ] + %5944 = phi float [ %8164, %__nv_exp2f.exit1417 ], [ %5619, %5575 ] + %5945 = phi float [ %8165, %__nv_exp2f.exit1417 ], [ %5620, %5575 ] + %5946 = phi float [ %8166, %__nv_exp2f.exit1417 ], [ %5621, %5575 ] + %5947 = phi float [ %8167, %__nv_exp2f.exit1417 ], [ %5622, %5575 ] + %5948 = phi float [ %8168, %__nv_exp2f.exit1417 ], [ %5623, %5575 ] + %5949 = phi float [ %8169, %__nv_exp2f.exit1417 ], [ %5624, %5575 ] + %5950 = phi float [ %8170, %__nv_exp2f.exit1417 ], [ %5625, %5575 ] + %5951 = phi float [ %8171, %__nv_exp2f.exit1417 ], [ %5626, %5575 ] + %5952 = phi float [ %8172, %__nv_exp2f.exit1417 ], [ %5627, %5575 ] + %5953 = phi float [ %8173, %__nv_exp2f.exit1417 ], [ %5628, %5575 ] + %5954 = phi float [ %8174, %__nv_exp2f.exit1417 ], [ %5629, %5575 ] + %5955 = phi float [ %8175, %__nv_exp2f.exit1417 ], [ %5630, %5575 ] + %5956 = phi float [ %8176, %__nv_exp2f.exit1417 ], [ %5631, %5575 ] + %5957 = phi float [ %8177, %__nv_exp2f.exit1417 ], [ %5632, %5575 ] + %5958 = phi float [ %8178, %__nv_exp2f.exit1417 ], [ %5633, %5575 ] + %5959 = phi float [ %8179, %__nv_exp2f.exit1417 ], [ %5634, %5575 ] + %5960 = phi float [ %8180, %__nv_exp2f.exit1417 ], [ %5635, %5575 ] + %5961 = phi float [ %8181, %__nv_exp2f.exit1417 ], [ %5636, %5575 ] + %5962 = phi float [ %8182, %__nv_exp2f.exit1417 ], [ %5637, %5575 ] + %5963 = phi float [ %8183, %__nv_exp2f.exit1417 ], [ %5638, %5575 ] + %5964 = phi float [ %8184, %__nv_exp2f.exit1417 ], [ %5639, %5575 ] + %5965 = phi i32 [ %8188, %__nv_exp2f.exit1417 ], [ 0, %5575 ] + %5966 = phi <16 x i32> [ %8187, %__nv_exp2f.exit1417 ], [ %5030, %5575 ] + %5967 = icmp slt i32 %5965, %5384, !dbg !274 + %5968 = icmp slt i32 %5965, %5385, !dbg !274 + %5969 = add i32 %5825, 1, !dbg !274 + %5970 = icmp sgt i32 %5969, 1, !dbg !274 + %5971 = select i1 %5970, i32 0, i32 %5969, !dbg !274 + %5972 = add i32 %5827, 1, !dbg !274 + %5973 = icmp sgt i32 %5972, 2, !dbg !274 + %5974 = select i1 %5973, i32 0, i32 %5972, !dbg !274 + %5975 = icmp slt i32 %5809, %17, !dbg !275 + %5976 = icmp slt i32 %5810, %17, !dbg !275 + %5977 = icmp slt i32 %5811, %17, !dbg !275 + %5978 = icmp slt i32 %5812, %17, !dbg !275 + %5979 = icmp slt i32 %5813, %17, !dbg !275 + %5980 = icmp slt i32 %5814, %17, !dbg !275 + %5981 = icmp slt i32 %5815, %17, !dbg !275 + %5982 = icmp slt i32 %5816, %17, !dbg !275 + %5983 = icmp slt i32 %5817, %17, !dbg !275 + %5984 = icmp slt i32 %5818, %17, !dbg !275 + %5985 = icmp slt i32 %5819, %17, !dbg !275 + %5986 = icmp slt i32 %5820, %17, !dbg !275 + %5987 = icmp slt i32 %5821, %17, !dbg !275 + %5988 = icmp slt i32 %5822, %17, !dbg !275 + %5989 = icmp slt i32 %5823, %17, !dbg !275 + %5990 = icmp slt i32 %5824, %17, !dbg !275 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !266 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !266 + %5991 = shl i32 %5974, 13, !dbg !266 + %5992 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %5991, !dbg !266 + %5993 = shl i32 %5971, 6, !dbg !268 + %5994 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %5993, !dbg !268 + %5995 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5179, !dbg !268 + %5996 = load float, ptr addrspace(3) %5995, align 8, !dbg !268 + %5997 = getelementptr inbounds nuw i8, ptr addrspace(3) %5995, i32 4, !dbg !268 + %5998 = load float, ptr addrspace(3) %5997, align 4, !dbg !268 + %5999 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5185, !dbg !268 + %6000 = load float, ptr addrspace(3) %5999, align 8, !dbg !268 + %6001 = getelementptr inbounds nuw i8, ptr addrspace(3) %5999, i32 4, !dbg !268 + %6002 = load float, ptr addrspace(3) %6001, align 4, !dbg !268 + %6003 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5191, !dbg !268 + %6004 = load float, ptr addrspace(3) %6003, align 8, !dbg !268 + %6005 = getelementptr inbounds nuw i8, ptr addrspace(3) %6003, i32 4, !dbg !268 + %6006 = load float, ptr addrspace(3) %6005, align 4, !dbg !268 + %6007 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5197, !dbg !268 + %6008 = load float, ptr addrspace(3) %6007, align 8, !dbg !268 + %6009 = getelementptr inbounds nuw i8, ptr addrspace(3) %6007, i32 4, !dbg !268 + %6010 = load float, ptr addrspace(3) %6009, align 4, !dbg !268 + %6011 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5203, !dbg !268 + %6012 = load float, ptr addrspace(3) %6011, align 8, !dbg !268 + %6013 = getelementptr inbounds nuw i8, ptr addrspace(3) %6011, i32 4, !dbg !268 + %6014 = load float, ptr addrspace(3) %6013, align 4, !dbg !268 + %6015 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5209, !dbg !268 + %6016 = load float, ptr addrspace(3) %6015, align 8, !dbg !268 + %6017 = getelementptr inbounds nuw i8, ptr addrspace(3) %6015, i32 4, !dbg !268 + %6018 = load float, ptr addrspace(3) %6017, align 4, !dbg !268 + %6019 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5215, !dbg !268 + %6020 = load float, ptr addrspace(3) %6019, align 8, !dbg !268 + %6021 = getelementptr inbounds nuw i8, ptr addrspace(3) %6019, i32 4, !dbg !268 + %6022 = load float, ptr addrspace(3) %6021, align 4, !dbg !268 + %6023 = getelementptr inbounds nuw i8, ptr addrspace(3) %5994, i32 %5221, !dbg !268 + %6024 = load float, ptr addrspace(3) %6023, align 8, !dbg !268 + %6025 = getelementptr inbounds nuw i8, ptr addrspace(3) %6023, i32 4, !dbg !268 + %6026 = load float, ptr addrspace(3) %6025, align 4, !dbg !268 + %6027 = fcmp oeq float %5996, 0xFFF0000000000000, !dbg !276 + %6028 = fcmp oeq float %5998, 0xFFF0000000000000, !dbg !276 + %6029 = fcmp oeq float %6000, 0xFFF0000000000000, !dbg !276 + %6030 = fcmp oeq float %6002, 0xFFF0000000000000, !dbg !276 + %6031 = fcmp oeq float %6004, 0xFFF0000000000000, !dbg !276 + %6032 = fcmp oeq float %6006, 0xFFF0000000000000, !dbg !276 + %6033 = fcmp oeq float %6008, 0xFFF0000000000000, !dbg !276 + %6034 = fcmp oeq float %6010, 0xFFF0000000000000, !dbg !276 + %6035 = fcmp oeq float %6012, 0xFFF0000000000000, !dbg !276 + %6036 = fcmp oeq float %6014, 0xFFF0000000000000, !dbg !276 + %6037 = fcmp oeq float %6016, 0xFFF0000000000000, !dbg !276 + %6038 = fcmp oeq float %6018, 0xFFF0000000000000, !dbg !276 + %6039 = fcmp oeq float %6020, 0xFFF0000000000000, !dbg !276 + %6040 = fcmp oeq float %6022, 0xFFF0000000000000, !dbg !276 + %6041 = fcmp oeq float %6024, 0xFFF0000000000000, !dbg !276 + %6042 = fcmp oeq float %6026, 0xFFF0000000000000, !dbg !276 + %6043 = select i1 %6027, float 0.000000e+00, float %5996, !dbg !277 + %6044 = select i1 %6028, float 0.000000e+00, float %5998, !dbg !277 + %6045 = select i1 %6029, float 0.000000e+00, float %6000, !dbg !277 + %6046 = select i1 %6030, float 0.000000e+00, float %6002, !dbg !277 + %6047 = select i1 %6031, float 0.000000e+00, float %6004, !dbg !277 + %6048 = select i1 %6032, float 0.000000e+00, float %6006, !dbg !277 + %6049 = select i1 %6033, float 0.000000e+00, float %6008, !dbg !277 + %6050 = select i1 %6034, float 0.000000e+00, float %6010, !dbg !277 + %6051 = select i1 %6035, float 0.000000e+00, float %6012, !dbg !277 + %6052 = select i1 %6036, float 0.000000e+00, float %6014, !dbg !277 + %6053 = select i1 %6037, float 0.000000e+00, float %6016, !dbg !277 + %6054 = select i1 %6038, float 0.000000e+00, float %6018, !dbg !277 + %6055 = select i1 %6039, float 0.000000e+00, float %6020, !dbg !277 + %6056 = select i1 %6040, float 0.000000e+00, float %6022, !dbg !277 + %6057 = select i1 %6041, float 0.000000e+00, float %6024, !dbg !277 + %6058 = select i1 %6042, float 0.000000e+00, float %6026, !dbg !277 + %6059 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %45, i32 0, i32 31), !dbg !248 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !248 + %6060 = shl i32 %6059, 11, !dbg !248 + %6061 = and i32 %6060, 8192, !dbg !248 + %6062 = add i32 %6061, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6063 = lshr exact i32 %6062, 4, !dbg !248 + %6064 = and i32 %6063, 16383, !dbg !248 + %6065 = zext nneg i32 %6064 to i64, !dbg !248 + %6066 = or disjoint i64 %6065, 4611686293372403712, !dbg !248 + %6067 = ptrtoint ptr addrspace(3) %5992 to i32, !dbg !248 + %6068 = lshr exact i32 %6067, 4, !dbg !248 + %6069 = and i32 %6068, 16383, !dbg !248 + %6070 = zext nneg i32 %6069 to i64, !dbg !248 + %6071 = or disjoint i64 %6070, 4611686293338849280, !dbg !248 + %6072 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %6066, i64 %6071) #3, !dbg !248 + %6073 = or disjoint i32 %6061, 32, !dbg !248 + %6074 = add i32 %6073, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6075 = lshr exact i32 %6074, 4, !dbg !248 + %6076 = and i32 %6075, 16383, !dbg !248 + %6077 = zext nneg i32 %6076 to i64, !dbg !248 + %6078 = or disjoint i64 %6077, 4611686293372403712, !dbg !248 + %6079 = add i32 %6067, 32, !dbg !248 + %6080 = lshr exact i32 %6079, 4, !dbg !248 + %6081 = and i32 %6080, 16383, !dbg !248 + %6082 = zext nneg i32 %6081 to i64, !dbg !248 + %6083 = or disjoint i64 %6082, 4611686293338849280, !dbg !248 + %6084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 0, !dbg !248 + %6085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 1, !dbg !248 + %6086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 2, !dbg !248 + %6087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 3, !dbg !248 + %6088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 4, !dbg !248 + %6089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 5, !dbg !248 + %6090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 6, !dbg !248 + %6091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 7, !dbg !248 + %6092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 8, !dbg !248 + %6093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 9, !dbg !248 + %6094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 10, !dbg !248 + %6095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 11, !dbg !248 + %6096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 12, !dbg !248 + %6097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 13, !dbg !248 + %6098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 14, !dbg !248 + %6099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 15, !dbg !248 + %6100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 16, !dbg !248 + %6101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 17, !dbg !248 + %6102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 18, !dbg !248 + %6103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 19, !dbg !248 + %6104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 20, !dbg !248 + %6105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 21, !dbg !248 + %6106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 22, !dbg !248 + %6107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 23, !dbg !248 + %6108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 24, !dbg !248 + %6109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 25, !dbg !248 + %6110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 26, !dbg !248 + %6111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 27, !dbg !248 + %6112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 28, !dbg !248 + %6113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 29, !dbg !248 + %6114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 30, !dbg !248 + %6115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6072, 31, !dbg !248 + %6116 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6084, float %6085, float %6086, float %6087, float %6088, float %6089, float %6090, float %6091, float %6092, float %6093, float %6094, float %6095, float %6096, float %6097, float %6098, float %6099, float %6100, float %6101, float %6102, float %6103, float %6104, float %6105, float %6106, float %6107, float %6108, float %6109, float %6110, float %6111, float %6112, float %6113, float %6114, float %6115, i64 %6078, i64 %6083, i1 true) #3, !dbg !248 + %6117 = or disjoint i32 %6061, 64, !dbg !248 + %6118 = add i32 %6117, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6119 = lshr exact i32 %6118, 4, !dbg !248 + %6120 = and i32 %6119, 16383, !dbg !248 + %6121 = zext nneg i32 %6120 to i64, !dbg !248 + %6122 = or disjoint i64 %6121, 4611686293372403712, !dbg !248 + %6123 = add i32 %6067, 64, !dbg !248 + %6124 = lshr exact i32 %6123, 4, !dbg !248 + %6125 = and i32 %6124, 16383, !dbg !248 + %6126 = zext nneg i32 %6125 to i64, !dbg !248 + %6127 = or disjoint i64 %6126, 4611686293338849280, !dbg !248 + %6128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 0, !dbg !248 + %6129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 1, !dbg !248 + %6130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 2, !dbg !248 + %6131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 3, !dbg !248 + %6132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 4, !dbg !248 + %6133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 5, !dbg !248 + %6134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 6, !dbg !248 + %6135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 7, !dbg !248 + %6136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 8, !dbg !248 + %6137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 9, !dbg !248 + %6138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 10, !dbg !248 + %6139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 11, !dbg !248 + %6140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 12, !dbg !248 + %6141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 13, !dbg !248 + %6142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 14, !dbg !248 + %6143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 15, !dbg !248 + %6144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 16, !dbg !248 + %6145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 17, !dbg !248 + %6146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 18, !dbg !248 + %6147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 19, !dbg !248 + %6148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 20, !dbg !248 + %6149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 21, !dbg !248 + %6150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 22, !dbg !248 + %6151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 23, !dbg !248 + %6152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 24, !dbg !248 + %6153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 25, !dbg !248 + %6154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 26, !dbg !248 + %6155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 27, !dbg !248 + %6156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 28, !dbg !248 + %6157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 29, !dbg !248 + %6158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 30, !dbg !248 + %6159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6116, 31, !dbg !248 + %6160 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6128, float %6129, float %6130, float %6131, float %6132, float %6133, float %6134, float %6135, float %6136, float %6137, float %6138, float %6139, float %6140, float %6141, float %6142, float %6143, float %6144, float %6145, float %6146, float %6147, float %6148, float %6149, float %6150, float %6151, float %6152, float %6153, float %6154, float %6155, float %6156, float %6157, float %6158, float %6159, i64 %6122, i64 %6127, i1 true) #3, !dbg !248 + %6161 = or disjoint i32 %6061, 96, !dbg !248 + %6162 = add i32 %6161, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6163 = lshr exact i32 %6162, 4, !dbg !248 + %6164 = and i32 %6163, 16383, !dbg !248 + %6165 = zext nneg i32 %6164 to i64, !dbg !248 + %6166 = or disjoint i64 %6165, 4611686293372403712, !dbg !248 + %6167 = add i32 %6067, 96, !dbg !248 + %6168 = lshr exact i32 %6167, 4, !dbg !248 + %6169 = and i32 %6168, 16383, !dbg !248 + %6170 = zext nneg i32 %6169 to i64, !dbg !248 + %6171 = or disjoint i64 %6170, 4611686293338849280, !dbg !248 + %6172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 0, !dbg !248 + %6173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 1, !dbg !248 + %6174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 2, !dbg !248 + %6175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 3, !dbg !248 + %6176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 4, !dbg !248 + %6177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 5, !dbg !248 + %6178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 6, !dbg !248 + %6179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 7, !dbg !248 + %6180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 8, !dbg !248 + %6181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 9, !dbg !248 + %6182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 10, !dbg !248 + %6183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 11, !dbg !248 + %6184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 12, !dbg !248 + %6185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 13, !dbg !248 + %6186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 14, !dbg !248 + %6187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 15, !dbg !248 + %6188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 16, !dbg !248 + %6189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 17, !dbg !248 + %6190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 18, !dbg !248 + %6191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 19, !dbg !248 + %6192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 20, !dbg !248 + %6193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 21, !dbg !248 + %6194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 22, !dbg !248 + %6195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 23, !dbg !248 + %6196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 24, !dbg !248 + %6197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 25, !dbg !248 + %6198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 26, !dbg !248 + %6199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 27, !dbg !248 + %6200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 28, !dbg !248 + %6201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 29, !dbg !248 + %6202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 30, !dbg !248 + %6203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6160, 31, !dbg !248 + %6204 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6172, float %6173, float %6174, float %6175, float %6176, float %6177, float %6178, float %6179, float %6180, float %6181, float %6182, float %6183, float %6184, float %6185, float %6186, float %6187, float %6188, float %6189, float %6190, float %6191, float %6192, float %6193, float %6194, float %6195, float %6196, float %6197, float %6198, float %6199, float %6200, float %6201, float %6202, float %6203, i64 %6166, i64 %6171, i1 true) #3, !dbg !248 + %6205 = or disjoint i32 %6061, 16384, !dbg !248 + %6206 = add i32 %6205, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6207 = lshr exact i32 %6206, 4, !dbg !248 + %6208 = and i32 %6207, 16383, !dbg !248 + %6209 = zext nneg i32 %6208 to i64, !dbg !248 + %6210 = or disjoint i64 %6209, 4611686293372403712, !dbg !248 + %6211 = add i32 %6067, 8192, !dbg !248 + %6212 = lshr exact i32 %6211, 4, !dbg !248 + %6213 = and i32 %6212, 16383, !dbg !248 + %6214 = zext nneg i32 %6213 to i64, !dbg !248 + %6215 = or disjoint i64 %6214, 4611686293338849280, !dbg !248 + %6216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 0, !dbg !248 + %6217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 1, !dbg !248 + %6218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 2, !dbg !248 + %6219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 3, !dbg !248 + %6220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 4, !dbg !248 + %6221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 5, !dbg !248 + %6222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 6, !dbg !248 + %6223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 7, !dbg !248 + %6224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 8, !dbg !248 + %6225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 9, !dbg !248 + %6226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 10, !dbg !248 + %6227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 11, !dbg !248 + %6228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 12, !dbg !248 + %6229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 13, !dbg !248 + %6230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 14, !dbg !248 + %6231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 15, !dbg !248 + %6232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 16, !dbg !248 + %6233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 17, !dbg !248 + %6234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 18, !dbg !248 + %6235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 19, !dbg !248 + %6236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 20, !dbg !248 + %6237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 21, !dbg !248 + %6238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 22, !dbg !248 + %6239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 23, !dbg !248 + %6240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 24, !dbg !248 + %6241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 25, !dbg !248 + %6242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 26, !dbg !248 + %6243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 27, !dbg !248 + %6244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 28, !dbg !248 + %6245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 29, !dbg !248 + %6246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 30, !dbg !248 + %6247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6204, 31, !dbg !248 + %6248 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6216, float %6217, float %6218, float %6219, float %6220, float %6221, float %6222, float %6223, float %6224, float %6225, float %6226, float %6227, float %6228, float %6229, float %6230, float %6231, float %6232, float %6233, float %6234, float %6235, float %6236, float %6237, float %6238, float %6239, float %6240, float %6241, float %6242, float %6243, float %6244, float %6245, float %6246, float %6247, i64 %6210, i64 %6215, i1 true) #3, !dbg !248 + %6249 = or disjoint i32 %6061, 16416, !dbg !248 + %6250 = add i32 %6249, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6251 = lshr exact i32 %6250, 4, !dbg !248 + %6252 = and i32 %6251, 16383, !dbg !248 + %6253 = zext nneg i32 %6252 to i64, !dbg !248 + %6254 = or disjoint i64 %6253, 4611686293372403712, !dbg !248 + %6255 = add i32 %6067, 8224, !dbg !248 + %6256 = lshr exact i32 %6255, 4, !dbg !248 + %6257 = and i32 %6256, 16383, !dbg !248 + %6258 = zext nneg i32 %6257 to i64, !dbg !248 + %6259 = or disjoint i64 %6258, 4611686293338849280, !dbg !248 + %6260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 0, !dbg !248 + %6261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 1, !dbg !248 + %6262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 2, !dbg !248 + %6263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 3, !dbg !248 + %6264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 4, !dbg !248 + %6265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 5, !dbg !248 + %6266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 6, !dbg !248 + %6267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 7, !dbg !248 + %6268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 8, !dbg !248 + %6269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 9, !dbg !248 + %6270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 10, !dbg !248 + %6271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 11, !dbg !248 + %6272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 12, !dbg !248 + %6273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 13, !dbg !248 + %6274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 14, !dbg !248 + %6275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 15, !dbg !248 + %6276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 16, !dbg !248 + %6277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 17, !dbg !248 + %6278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 18, !dbg !248 + %6279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 19, !dbg !248 + %6280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 20, !dbg !248 + %6281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 21, !dbg !248 + %6282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 22, !dbg !248 + %6283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 23, !dbg !248 + %6284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 24, !dbg !248 + %6285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 25, !dbg !248 + %6286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 26, !dbg !248 + %6287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 27, !dbg !248 + %6288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 28, !dbg !248 + %6289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 29, !dbg !248 + %6290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 30, !dbg !248 + %6291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6248, 31, !dbg !248 + %6292 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6260, float %6261, float %6262, float %6263, float %6264, float %6265, float %6266, float %6267, float %6268, float %6269, float %6270, float %6271, float %6272, float %6273, float %6274, float %6275, float %6276, float %6277, float %6278, float %6279, float %6280, float %6281, float %6282, float %6283, float %6284, float %6285, float %6286, float %6287, float %6288, float %6289, float %6290, float %6291, i64 %6254, i64 %6259, i1 true) #3, !dbg !248 + %6293 = or disjoint i32 %6061, 16448, !dbg !248 + %6294 = add i32 %6293, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6295 = lshr exact i32 %6294, 4, !dbg !248 + %6296 = and i32 %6295, 16383, !dbg !248 + %6297 = zext nneg i32 %6296 to i64, !dbg !248 + %6298 = or disjoint i64 %6297, 4611686293372403712, !dbg !248 + %6299 = add i32 %6067, 8256, !dbg !248 + %6300 = lshr exact i32 %6299, 4, !dbg !248 + %6301 = and i32 %6300, 16383, !dbg !248 + %6302 = zext nneg i32 %6301 to i64, !dbg !248 + %6303 = or disjoint i64 %6302, 4611686293338849280, !dbg !248 + %6304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 0, !dbg !248 + %6305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 1, !dbg !248 + %6306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 2, !dbg !248 + %6307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 3, !dbg !248 + %6308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 4, !dbg !248 + %6309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 5, !dbg !248 + %6310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 6, !dbg !248 + %6311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 7, !dbg !248 + %6312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 8, !dbg !248 + %6313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 9, !dbg !248 + %6314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 10, !dbg !248 + %6315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 11, !dbg !248 + %6316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 12, !dbg !248 + %6317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 13, !dbg !248 + %6318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 14, !dbg !248 + %6319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 15, !dbg !248 + %6320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 16, !dbg !248 + %6321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 17, !dbg !248 + %6322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 18, !dbg !248 + %6323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 19, !dbg !248 + %6324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 20, !dbg !248 + %6325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 21, !dbg !248 + %6326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 22, !dbg !248 + %6327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 23, !dbg !248 + %6328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 24, !dbg !248 + %6329 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 25, !dbg !248 + %6330 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 26, !dbg !248 + %6331 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 27, !dbg !248 + %6332 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 28, !dbg !248 + %6333 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 29, !dbg !248 + %6334 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 30, !dbg !248 + %6335 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6292, 31, !dbg !248 + %6336 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6304, float %6305, float %6306, float %6307, float %6308, float %6309, float %6310, float %6311, float %6312, float %6313, float %6314, float %6315, float %6316, float %6317, float %6318, float %6319, float %6320, float %6321, float %6322, float %6323, float %6324, float %6325, float %6326, float %6327, float %6328, float %6329, float %6330, float %6331, float %6332, float %6333, float %6334, float %6335, i64 %6298, i64 %6303, i1 true) #3, !dbg !248 + %6337 = or disjoint i32 %6061, 16480, !dbg !248 + %6338 = add i32 %6337, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !248 + %6339 = lshr exact i32 %6338, 4, !dbg !248 + %6340 = and i32 %6339, 16383, !dbg !248 + %6341 = zext nneg i32 %6340 to i64, !dbg !248 + %6342 = or disjoint i64 %6341, 4611686293372403712, !dbg !248 + %6343 = add i32 %6067, 8288, !dbg !248 + %6344 = lshr exact i32 %6343, 4, !dbg !248 + %6345 = and i32 %6344, 16383, !dbg !248 + %6346 = zext nneg i32 %6345 to i64, !dbg !248 + %6347 = or disjoint i64 %6346, 4611686293338849280, !dbg !248 + %6348 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 0, !dbg !248 + %6349 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 1, !dbg !248 + %6350 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 2, !dbg !248 + %6351 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 3, !dbg !248 + %6352 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 4, !dbg !248 + %6353 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 5, !dbg !248 + %6354 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 6, !dbg !248 + %6355 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 7, !dbg !248 + %6356 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 8, !dbg !248 + %6357 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 9, !dbg !248 + %6358 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 10, !dbg !248 + %6359 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 11, !dbg !248 + %6360 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 12, !dbg !248 + %6361 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 13, !dbg !248 + %6362 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 14, !dbg !248 + %6363 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 15, !dbg !248 + %6364 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 16, !dbg !248 + %6365 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 17, !dbg !248 + %6366 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 18, !dbg !248 + %6367 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 19, !dbg !248 + %6368 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 20, !dbg !248 + %6369 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 21, !dbg !248 + %6370 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 22, !dbg !248 + %6371 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 23, !dbg !248 + %6372 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 24, !dbg !248 + %6373 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 25, !dbg !248 + %6374 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 26, !dbg !248 + %6375 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 27, !dbg !248 + %6376 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 28, !dbg !248 + %6377 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 29, !dbg !248 + %6378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 30, !dbg !248 + %6379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6336, 31, !dbg !248 + %6380 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %6348, float %6349, float %6350, float %6351, float %6352, float %6353, float %6354, float %6355, float %6356, float %6357, float %6358, float %6359, float %6360, float %6361, float %6362, float %6363, float %6364, float %6365, float %6366, float %6367, float %6368, float %6369, float %6370, float %6371, float %6372, float %6373, float %6374, float %6375, float %6376, float %6377, float %6378, float %6379, i64 %6342, i64 %6347, i1 true) #3, !dbg !248 + %6381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 0, !dbg !248 + %6382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 1, !dbg !248 + %6383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 2, !dbg !248 + %6384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 3, !dbg !248 + %6385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 4, !dbg !248 + %6386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 5, !dbg !248 + %6387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 6, !dbg !248 + %6388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 7, !dbg !248 + %6389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 8, !dbg !248 + %6390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 9, !dbg !248 + %6391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 10, !dbg !248 + %6392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 11, !dbg !248 + %6393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 12, !dbg !248 + %6394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 13, !dbg !248 + %6395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 14, !dbg !248 + %6396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 15, !dbg !248 + %6397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 16, !dbg !248 + %6398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 17, !dbg !248 + %6399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 18, !dbg !248 + %6400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 19, !dbg !248 + %6401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 20, !dbg !248 + %6402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 21, !dbg !248 + %6403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 22, !dbg !248 + %6404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 23, !dbg !248 + %6405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 24, !dbg !248 + %6406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 25, !dbg !248 + %6407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 26, !dbg !248 + %6408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 27, !dbg !248 + %6409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 28, !dbg !248 + %6410 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 29, !dbg !248 + %6411 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 30, !dbg !248 + %6412 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %6380, 31, !dbg !248 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !248 + %6413 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %6381, float %6382, float %6383, float %6384, float %6385, float %6386, float %6387, float %6388, float %6389, float %6390, float %6391, float %6392, float %6393, float %6394, float %6395, float %6396, float %6397, float %6398, float %6399, float %6400, float %6401, float %6402, float %6403, float %6404, float %6405, float %6406, float %6407, float %6408, float %6409, float %6410, float %6411, float %6412, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %5992, i32 0, i32 0) #3, !dbg !248 + %6414 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 0, !dbg !248 + %6415 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 1, !dbg !248 + %6416 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 2, !dbg !248 + %6417 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 3, !dbg !248 + %6418 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 4, !dbg !248 + %6419 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 5, !dbg !248 + %6420 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 6, !dbg !248 + %6421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 7, !dbg !248 + %6422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 8, !dbg !248 + %6423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 9, !dbg !248 + %6424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 10, !dbg !248 + %6425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 11, !dbg !248 + %6426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 12, !dbg !248 + %6427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 13, !dbg !248 + %6428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 14, !dbg !248 + %6429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 15, !dbg !248 + %6430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 16, !dbg !248 + %6431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 17, !dbg !248 + %6432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 18, !dbg !248 + %6433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 19, !dbg !248 + %6434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 20, !dbg !248 + %6435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 21, !dbg !248 + %6436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 22, !dbg !248 + %6437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 23, !dbg !248 + %6438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 24, !dbg !248 + %6439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 25, !dbg !248 + %6440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 26, !dbg !248 + %6441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 27, !dbg !248 + %6442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 28, !dbg !248 + %6443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 29, !dbg !248 + %6444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 30, !dbg !248 + %6445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %6413, 31, !dbg !248 + %6446 = fmul float %6414, 0x3FB6A09E60000000, !dbg !278 + %6447 = fmul float %6415, 0x3FB6A09E60000000, !dbg !278 + %6448 = fmul float %6416, 0x3FB6A09E60000000, !dbg !278 + %6449 = fmul float %6417, 0x3FB6A09E60000000, !dbg !278 + %6450 = fmul float %6418, 0x3FB6A09E60000000, !dbg !278 + %6451 = fmul float %6419, 0x3FB6A09E60000000, !dbg !278 + %6452 = fmul float %6420, 0x3FB6A09E60000000, !dbg !278 + %6453 = fmul float %6421, 0x3FB6A09E60000000, !dbg !278 + %6454 = fmul float %6422, 0x3FB6A09E60000000, !dbg !278 + %6455 = fmul float %6423, 0x3FB6A09E60000000, !dbg !278 + %6456 = fmul float %6424, 0x3FB6A09E60000000, !dbg !278 + %6457 = fmul float %6425, 0x3FB6A09E60000000, !dbg !278 + %6458 = fmul float %6426, 0x3FB6A09E60000000, !dbg !278 + %6459 = fmul float %6427, 0x3FB6A09E60000000, !dbg !278 + %6460 = fmul float %6428, 0x3FB6A09E60000000, !dbg !278 + %6461 = fmul float %6429, 0x3FB6A09E60000000, !dbg !278 + %6462 = fmul float %6430, 0x3FB6A09E60000000, !dbg !278 + %6463 = fmul float %6431, 0x3FB6A09E60000000, !dbg !278 + %6464 = fmul float %6432, 0x3FB6A09E60000000, !dbg !278 + %6465 = fmul float %6433, 0x3FB6A09E60000000, !dbg !278 + %6466 = fmul float %6434, 0x3FB6A09E60000000, !dbg !278 + %6467 = fmul float %6435, 0x3FB6A09E60000000, !dbg !278 + %6468 = fmul float %6436, 0x3FB6A09E60000000, !dbg !278 + %6469 = fmul float %6437, 0x3FB6A09E60000000, !dbg !278 + %6470 = fmul float %6438, 0x3FB6A09E60000000, !dbg !278 + %6471 = fmul float %6439, 0x3FB6A09E60000000, !dbg !278 + %6472 = fmul float %6440, 0x3FB6A09E60000000, !dbg !278 + %6473 = fmul float %6441, 0x3FB6A09E60000000, !dbg !278 + %6474 = fmul float %6442, 0x3FB6A09E60000000, !dbg !278 + %6475 = fmul float %6443, 0x3FB6A09E60000000, !dbg !278 + %6476 = fmul float %6444, 0x3FB6A09E60000000, !dbg !278 + %6477 = fmul float %6445, 0x3FB6A09E60000000, !dbg !278 + %6478 = srem <16 x i32> %5966, %5572, !dbg !235 + %6479 = icmp slt <16 x i32> %6478, zeroinitializer, !dbg !279 + %6480 = extractelement <16 x i32> %6478, i64 15, !dbg !280 + %6481 = icmp sle i32 %5573, %6480, !dbg !281 + %6482 = extractelement <16 x i32> %6478, i64 14, !dbg !280 + %6483 = icmp sle i32 %5573, %6482, !dbg !281 + %6484 = icmp sle i32 %5574, %6480, !dbg !281 + %6485 = icmp sle i32 %5574, %6482, !dbg !281 + %6486 = extractelement <16 x i32> %6478, i64 13, !dbg !280 + %6487 = icmp sle i32 %5573, %6486, !dbg !281 + %6488 = extractelement <16 x i32> %6478, i64 12, !dbg !280 + %6489 = icmp sle i32 %5573, %6488, !dbg !281 + %6490 = icmp sle i32 %5574, %6486, !dbg !281 + %6491 = icmp sle i32 %5574, %6488, !dbg !281 + %6492 = extractelement <16 x i32> %6478, i64 11, !dbg !280 + %6493 = icmp sle i32 %5573, %6492, !dbg !281 + %6494 = extractelement <16 x i32> %6478, i64 10, !dbg !280 + %6495 = icmp sle i32 %5573, %6494, !dbg !281 + %6496 = icmp sle i32 %5574, %6492, !dbg !281 + %6497 = icmp sle i32 %5574, %6494, !dbg !281 + %6498 = extractelement <16 x i32> %6478, i64 9, !dbg !280 + %6499 = icmp sle i32 %5573, %6498, !dbg !281 + %6500 = extractelement <16 x i32> %6478, i64 8, !dbg !280 + %6501 = icmp sle i32 %5573, %6500, !dbg !281 + %6502 = icmp sle i32 %5574, %6498, !dbg !281 + %6503 = icmp sle i32 %5574, %6500, !dbg !281 + %6504 = extractelement <16 x i32> %6478, i64 7, !dbg !280 + %6505 = icmp sle i32 %5573, %6504, !dbg !281 + %6506 = extractelement <16 x i32> %6478, i64 6, !dbg !280 + %6507 = icmp sle i32 %5573, %6506, !dbg !281 + %6508 = icmp sle i32 %5574, %6504, !dbg !281 + %6509 = icmp sle i32 %5574, %6506, !dbg !281 + %6510 = extractelement <16 x i32> %6478, i64 5, !dbg !280 + %6511 = icmp sle i32 %5573, %6510, !dbg !281 + %6512 = extractelement <16 x i32> %6478, i64 4, !dbg !280 + %6513 = icmp sle i32 %5573, %6512, !dbg !281 + %6514 = icmp sle i32 %5574, %6510, !dbg !281 + %6515 = icmp sle i32 %5574, %6512, !dbg !281 + %6516 = extractelement <16 x i32> %6478, i64 3, !dbg !280 + %6517 = icmp sle i32 %5573, %6516, !dbg !281 + %6518 = extractelement <16 x i32> %6478, i64 2, !dbg !280 + %6519 = icmp sle i32 %5573, %6518, !dbg !281 + %6520 = icmp sle i32 %5574, %6516, !dbg !281 + %6521 = icmp sle i32 %5574, %6518, !dbg !281 + %6522 = extractelement <16 x i32> %6478, i64 1, !dbg !280 + %6523 = icmp sle i32 %5573, %6522, !dbg !281 + %6524 = extractelement <16 x i32> %6478, i64 0, !dbg !280 + %6525 = icmp sle i32 %5573, %6524, !dbg !281 + %6526 = icmp sle i32 %5574, %6522, !dbg !281 + %6527 = icmp sle i32 %5574, %6524, !dbg !281 + %6528 = extractelement <16 x i1> %6479, i64 15, !dbg !282 + %6529 = and i1 %6528, %6481, !dbg !282 + %6530 = extractelement <16 x i1> %6479, i64 14, !dbg !282 + %6531 = and i1 %6530, %6483, !dbg !282 + %6532 = and i1 %6528, %6484, !dbg !282 + %6533 = and i1 %6530, %6485, !dbg !282 + %6534 = extractelement <16 x i1> %6479, i64 13, !dbg !282 + %6535 = and i1 %6534, %6487, !dbg !282 + %6536 = extractelement <16 x i1> %6479, i64 12, !dbg !282 + %6537 = and i1 %6536, %6489, !dbg !282 + %6538 = and i1 %6534, %6490, !dbg !282 + %6539 = and i1 %6536, %6491, !dbg !282 + %6540 = extractelement <16 x i1> %6479, i64 11, !dbg !282 + %6541 = and i1 %6540, %6493, !dbg !282 + %6542 = extractelement <16 x i1> %6479, i64 10, !dbg !282 + %6543 = and i1 %6542, %6495, !dbg !282 + %6544 = and i1 %6540, %6496, !dbg !282 + %6545 = and i1 %6542, %6497, !dbg !282 + %6546 = extractelement <16 x i1> %6479, i64 9, !dbg !282 + %6547 = and i1 %6546, %6499, !dbg !282 + %6548 = extractelement <16 x i1> %6479, i64 8, !dbg !282 + %6549 = and i1 %6548, %6501, !dbg !282 + %6550 = and i1 %6546, %6502, !dbg !282 + %6551 = and i1 %6548, %6503, !dbg !282 + %6552 = extractelement <16 x i1> %6479, i64 7, !dbg !282 + %6553 = and i1 %6552, %6505, !dbg !282 + %6554 = extractelement <16 x i1> %6479, i64 6, !dbg !282 + %6555 = and i1 %6554, %6507, !dbg !282 + %6556 = and i1 %6552, %6508, !dbg !282 + %6557 = and i1 %6554, %6509, !dbg !282 + %6558 = extractelement <16 x i1> %6479, i64 5, !dbg !282 + %6559 = and i1 %6558, %6511, !dbg !282 + %6560 = extractelement <16 x i1> %6479, i64 4, !dbg !282 + %6561 = and i1 %6560, %6513, !dbg !282 + %6562 = and i1 %6558, %6514, !dbg !282 + %6563 = and i1 %6560, %6515, !dbg !282 + %6564 = extractelement <16 x i1> %6479, i64 3, !dbg !282 + %6565 = and i1 %6564, %6517, !dbg !282 + %6566 = extractelement <16 x i1> %6479, i64 2, !dbg !282 + %6567 = and i1 %6566, %6519, !dbg !282 + %6568 = and i1 %6564, %6520, !dbg !282 + %6569 = and i1 %6566, %6521, !dbg !282 + %6570 = extractelement <16 x i1> %6479, i64 1, !dbg !282 + %6571 = and i1 %6570, %6523, !dbg !282 + %6572 = extractelement <16 x i1> %6479, i64 0, !dbg !282 + %6573 = and i1 %6572, %6525, !dbg !282 + %6574 = and i1 %6570, %6526, !dbg !282 + %6575 = and i1 %6572, %6527, !dbg !282 + %6576 = icmp sgt i32 %6480, -1, !dbg !280 + %6577 = icmp sgt i32 %6482, -1, !dbg !280 + %6578 = icmp sgt i32 %6486, -1, !dbg !280 + %6579 = icmp sgt i32 %6488, -1, !dbg !280 + %6580 = icmp sgt i32 %6492, -1, !dbg !280 + %6581 = icmp sgt i32 %6494, -1, !dbg !280 + %6582 = icmp sgt i32 %6498, -1, !dbg !280 + %6583 = icmp sgt i32 %6500, -1, !dbg !280 + %6584 = icmp sgt i32 %6504, -1, !dbg !280 + %6585 = icmp sgt i32 %6506, -1, !dbg !280 + %6586 = icmp sgt i32 %6510, -1, !dbg !280 + %6587 = icmp sgt i32 %6512, -1, !dbg !280 + %6588 = icmp sgt i32 %6516, -1, !dbg !280 + %6589 = icmp sgt i32 %6518, -1, !dbg !280 + %6590 = icmp sgt i32 %6522, -1, !dbg !280 + %6591 = icmp sgt i32 %6524, -1, !dbg !280 + %6592 = and <16 x i32> %6478, splat (i32 15), !dbg !283 + %6593 = icmp ne <16 x i32> %6592, zeroinitializer, !dbg !283 + %6594 = sdiv <16 x i32> %6478, splat (i32 16), !dbg !284 + %6595 = and <16 x i1> %6479, %6593, !dbg !285 + %6596 = sext <16 x i1> %6595 to <16 x i32>, !dbg !285 + %6597 = add nsw <16 x i32> %6594, %6596, !dbg !285 + %6598 = shufflevector <16 x i32> %6597, <16 x i32> poison, <32 x i32> , !dbg !285 + %6599 = icmp eq <32 x i32> %6598, %5035, !dbg !286 + %6600 = extractelement <32 x i1> %6599, i64 31, !dbg !287 + %6601 = and i1 %6576, %6600, !dbg !287 + %6602 = extractelement <32 x i1> %6599, i64 30, !dbg !287 + %6603 = and i1 %6577, %6602, !dbg !287 + %6604 = extractelement <32 x i1> %6599, i64 29, !dbg !287 + %6605 = and i1 %6576, %6604, !dbg !287 + %6606 = extractelement <32 x i1> %6599, i64 28, !dbg !287 + %6607 = and i1 %6577, %6606, !dbg !287 + %6608 = extractelement <32 x i1> %6599, i64 27, !dbg !287 + %6609 = and i1 %6578, %6608, !dbg !287 + %6610 = extractelement <32 x i1> %6599, i64 26, !dbg !287 + %6611 = and i1 %6579, %6610, !dbg !287 + %6612 = extractelement <32 x i1> %6599, i64 25, !dbg !287 + %6613 = and i1 %6578, %6612, !dbg !287 + %6614 = extractelement <32 x i1> %6599, i64 24, !dbg !287 + %6615 = and i1 %6579, %6614, !dbg !287 + %6616 = extractelement <32 x i1> %6599, i64 23, !dbg !287 + %6617 = and i1 %6580, %6616, !dbg !287 + %6618 = extractelement <32 x i1> %6599, i64 22, !dbg !287 + %6619 = and i1 %6581, %6618, !dbg !287 + %6620 = extractelement <32 x i1> %6599, i64 21, !dbg !287 + %6621 = and i1 %6580, %6620, !dbg !287 + %6622 = extractelement <32 x i1> %6599, i64 20, !dbg !287 + %6623 = and i1 %6581, %6622, !dbg !287 + %6624 = extractelement <32 x i1> %6599, i64 19, !dbg !287 + %6625 = and i1 %6582, %6624, !dbg !287 + %6626 = extractelement <32 x i1> %6599, i64 18, !dbg !287 + %6627 = and i1 %6583, %6626, !dbg !287 + %6628 = extractelement <32 x i1> %6599, i64 17, !dbg !287 + %6629 = and i1 %6582, %6628, !dbg !287 + %6630 = extractelement <32 x i1> %6599, i64 16, !dbg !287 + %6631 = and i1 %6583, %6630, !dbg !287 + %6632 = extractelement <32 x i1> %6599, i64 15, !dbg !287 + %6633 = and i1 %6584, %6632, !dbg !287 + %6634 = extractelement <32 x i1> %6599, i64 14, !dbg !287 + %6635 = and i1 %6585, %6634, !dbg !287 + %6636 = extractelement <32 x i1> %6599, i64 13, !dbg !287 + %6637 = and i1 %6584, %6636, !dbg !287 + %6638 = extractelement <32 x i1> %6599, i64 12, !dbg !287 + %6639 = and i1 %6585, %6638, !dbg !287 + %6640 = extractelement <32 x i1> %6599, i64 11, !dbg !287 + %6641 = and i1 %6586, %6640, !dbg !287 + %6642 = extractelement <32 x i1> %6599, i64 10, !dbg !287 + %6643 = and i1 %6587, %6642, !dbg !287 + %6644 = extractelement <32 x i1> %6599, i64 9, !dbg !287 + %6645 = and i1 %6586, %6644, !dbg !287 + %6646 = extractelement <32 x i1> %6599, i64 8, !dbg !287 + %6647 = and i1 %6587, %6646, !dbg !287 + %6648 = extractelement <32 x i1> %6599, i64 7, !dbg !287 + %6649 = and i1 %6588, %6648, !dbg !287 + %6650 = extractelement <32 x i1> %6599, i64 6, !dbg !287 + %6651 = and i1 %6589, %6650, !dbg !287 + %6652 = extractelement <32 x i1> %6599, i64 5, !dbg !287 + %6653 = and i1 %6588, %6652, !dbg !287 + %6654 = extractelement <32 x i1> %6599, i64 4, !dbg !287 + %6655 = and i1 %6589, %6654, !dbg !287 + %6656 = extractelement <32 x i1> %6599, i64 3, !dbg !287 + %6657 = and i1 %6590, %6656, !dbg !287 + %6658 = extractelement <32 x i1> %6599, i64 2, !dbg !287 + %6659 = and i1 %6591, %6658, !dbg !287 + %6660 = extractelement <32 x i1> %6599, i64 1, !dbg !287 + %6661 = and i1 %6590, %6660, !dbg !287 + %6662 = extractelement <32 x i1> %6599, i64 0, !dbg !287 + %6663 = and i1 %6591, %6662, !dbg !287 + %6664 = or i1 %6529, %6601, !dbg !288 + %6665 = or i1 %6531, %6603, !dbg !288 + %6666 = or i1 %6532, %6605, !dbg !288 + %6667 = or i1 %6533, %6607, !dbg !288 + %6668 = or i1 %6535, %6609, !dbg !288 + %6669 = or i1 %6537, %6611, !dbg !288 + %6670 = or i1 %6538, %6613, !dbg !288 + %6671 = or i1 %6539, %6615, !dbg !288 + %6672 = or i1 %6541, %6617, !dbg !288 + %6673 = or i1 %6543, %6619, !dbg !288 + %6674 = or i1 %6544, %6621, !dbg !288 + %6675 = or i1 %6545, %6623, !dbg !288 + %6676 = or i1 %6547, %6625, !dbg !288 + %6677 = or i1 %6549, %6627, !dbg !288 + %6678 = or i1 %6550, %6629, !dbg !288 + %6679 = or i1 %6551, %6631, !dbg !288 + %6680 = or i1 %6553, %6633, !dbg !288 + %6681 = or i1 %6555, %6635, !dbg !288 + %6682 = or i1 %6556, %6637, !dbg !288 + %6683 = or i1 %6557, %6639, !dbg !288 + %6684 = or i1 %6559, %6641, !dbg !288 + %6685 = or i1 %6561, %6643, !dbg !288 + %6686 = or i1 %6562, %6645, !dbg !288 + %6687 = or i1 %6563, %6647, !dbg !288 + %6688 = or i1 %6565, %6649, !dbg !288 + %6689 = or i1 %6567, %6651, !dbg !288 + %6690 = or i1 %6568, %6653, !dbg !288 + %6691 = or i1 %6569, %6655, !dbg !288 + %6692 = or i1 %6571, %6657, !dbg !288 + %6693 = or i1 %6573, %6659, !dbg !288 + %6694 = or i1 %6574, %6661, !dbg !288 + %6695 = or i1 %6575, %6663, !dbg !288 + %6696 = select i1 %6664, i1 %5975, i1 false, !dbg !289 + %6697 = select i1 %6665, i1 %5976, i1 false, !dbg !289 + %6698 = select i1 %6666, i1 %5975, i1 false, !dbg !289 + %6699 = select i1 %6667, i1 %5976, i1 false, !dbg !289 + %6700 = select i1 %6668, i1 %5977, i1 false, !dbg !289 + %6701 = select i1 %6669, i1 %5978, i1 false, !dbg !289 + %6702 = select i1 %6670, i1 %5977, i1 false, !dbg !289 + %6703 = select i1 %6671, i1 %5978, i1 false, !dbg !289 + %6704 = select i1 %6672, i1 %5979, i1 false, !dbg !289 + %6705 = select i1 %6673, i1 %5980, i1 false, !dbg !289 + %6706 = select i1 %6674, i1 %5979, i1 false, !dbg !289 + %6707 = select i1 %6675, i1 %5980, i1 false, !dbg !289 + %6708 = select i1 %6676, i1 %5981, i1 false, !dbg !289 + %6709 = select i1 %6677, i1 %5982, i1 false, !dbg !289 + %6710 = select i1 %6678, i1 %5981, i1 false, !dbg !289 + %6711 = select i1 %6679, i1 %5982, i1 false, !dbg !289 + %6712 = select i1 %6680, i1 %5983, i1 false, !dbg !289 + %6713 = select i1 %6681, i1 %5984, i1 false, !dbg !289 + %6714 = select i1 %6682, i1 %5983, i1 false, !dbg !289 + %6715 = select i1 %6683, i1 %5984, i1 false, !dbg !289 + %6716 = select i1 %6684, i1 %5985, i1 false, !dbg !289 + %6717 = select i1 %6685, i1 %5986, i1 false, !dbg !289 + %6718 = select i1 %6686, i1 %5985, i1 false, !dbg !289 + %6719 = select i1 %6687, i1 %5986, i1 false, !dbg !289 + %6720 = select i1 %6688, i1 %5987, i1 false, !dbg !289 + %6721 = select i1 %6689, i1 %5988, i1 false, !dbg !289 + %6722 = select i1 %6690, i1 %5987, i1 false, !dbg !289 + %6723 = select i1 %6691, i1 %5988, i1 false, !dbg !289 + %6724 = select i1 %6692, i1 %5989, i1 false, !dbg !289 + %6725 = select i1 %6693, i1 %5990, i1 false, !dbg !289 + %6726 = select i1 %6694, i1 %5989, i1 false, !dbg !289 + %6727 = select i1 %6695, i1 %5990, i1 false, !dbg !289 + %6728 = fmul float %6446, 0x3FF7154760000000, !dbg !290 + %6729 = select i1 %6696, float %6728, float 0xFFF0000000000000, !dbg !289 + %6730 = fmul float %6447, 0x3FF7154760000000, !dbg !290 + %6731 = select i1 %6697, float %6730, float 0xFFF0000000000000, !dbg !289 + %6732 = fmul float %6448, 0x3FF7154760000000, !dbg !290 + %6733 = select i1 %6698, float %6732, float 0xFFF0000000000000, !dbg !289 + %6734 = fmul float %6449, 0x3FF7154760000000, !dbg !290 + %6735 = select i1 %6699, float %6734, float 0xFFF0000000000000, !dbg !289 + %6736 = fmul float %6450, 0x3FF7154760000000, !dbg !290 + %6737 = select i1 %6700, float %6736, float 0xFFF0000000000000, !dbg !289 + %6738 = fmul float %6451, 0x3FF7154760000000, !dbg !290 + %6739 = select i1 %6701, float %6738, float 0xFFF0000000000000, !dbg !289 + %6740 = fmul float %6452, 0x3FF7154760000000, !dbg !290 + %6741 = select i1 %6702, float %6740, float 0xFFF0000000000000, !dbg !289 + %6742 = fmul float %6453, 0x3FF7154760000000, !dbg !290 + %6743 = select i1 %6703, float %6742, float 0xFFF0000000000000, !dbg !289 + %6744 = fmul float %6454, 0x3FF7154760000000, !dbg !290 + %6745 = select i1 %6704, float %6744, float 0xFFF0000000000000, !dbg !289 + %6746 = fmul float %6455, 0x3FF7154760000000, !dbg !290 + %6747 = select i1 %6705, float %6746, float 0xFFF0000000000000, !dbg !289 + %6748 = fmul float %6456, 0x3FF7154760000000, !dbg !290 + %6749 = select i1 %6706, float %6748, float 0xFFF0000000000000, !dbg !289 + %6750 = fmul float %6457, 0x3FF7154760000000, !dbg !290 + %6751 = select i1 %6707, float %6750, float 0xFFF0000000000000, !dbg !289 + %6752 = fmul float %6458, 0x3FF7154760000000, !dbg !290 + %6753 = select i1 %6708, float %6752, float 0xFFF0000000000000, !dbg !289 + %6754 = fmul float %6459, 0x3FF7154760000000, !dbg !290 + %6755 = select i1 %6709, float %6754, float 0xFFF0000000000000, !dbg !289 + %6756 = fmul float %6460, 0x3FF7154760000000, !dbg !290 + %6757 = select i1 %6710, float %6756, float 0xFFF0000000000000, !dbg !289 + %6758 = fmul float %6461, 0x3FF7154760000000, !dbg !290 + %6759 = select i1 %6711, float %6758, float 0xFFF0000000000000, !dbg !289 + %6760 = fmul float %6462, 0x3FF7154760000000, !dbg !290 + %6761 = select i1 %6712, float %6760, float 0xFFF0000000000000, !dbg !289 + %6762 = fmul float %6463, 0x3FF7154760000000, !dbg !290 + %6763 = select i1 %6713, float %6762, float 0xFFF0000000000000, !dbg !289 + %6764 = fmul float %6464, 0x3FF7154760000000, !dbg !290 + %6765 = select i1 %6714, float %6764, float 0xFFF0000000000000, !dbg !289 + %6766 = fmul float %6465, 0x3FF7154760000000, !dbg !290 + %6767 = select i1 %6715, float %6766, float 0xFFF0000000000000, !dbg !289 + %6768 = fmul float %6466, 0x3FF7154760000000, !dbg !290 + %6769 = select i1 %6716, float %6768, float 0xFFF0000000000000, !dbg !289 + %6770 = fmul float %6467, 0x3FF7154760000000, !dbg !290 + %6771 = select i1 %6717, float %6770, float 0xFFF0000000000000, !dbg !289 + %6772 = fmul float %6468, 0x3FF7154760000000, !dbg !290 + %6773 = select i1 %6718, float %6772, float 0xFFF0000000000000, !dbg !289 + %6774 = fmul float %6469, 0x3FF7154760000000, !dbg !290 + %6775 = select i1 %6719, float %6774, float 0xFFF0000000000000, !dbg !289 + %6776 = fmul float %6470, 0x3FF7154760000000, !dbg !290 + %6777 = select i1 %6720, float %6776, float 0xFFF0000000000000, !dbg !289 + %6778 = fmul float %6471, 0x3FF7154760000000, !dbg !290 + %6779 = select i1 %6721, float %6778, float 0xFFF0000000000000, !dbg !289 + %6780 = fmul float %6472, 0x3FF7154760000000, !dbg !290 + %6781 = select i1 %6722, float %6780, float 0xFFF0000000000000, !dbg !289 + %6782 = fmul float %6473, 0x3FF7154760000000, !dbg !290 + %6783 = select i1 %6723, float %6782, float 0xFFF0000000000000, !dbg !289 + %6784 = fmul float %6474, 0x3FF7154760000000, !dbg !290 + %6785 = select i1 %6724, float %6784, float 0xFFF0000000000000, !dbg !289 + %6786 = fmul float %6475, 0x3FF7154760000000, !dbg !290 + %6787 = select i1 %6725, float %6786, float 0xFFF0000000000000, !dbg !289 + %6788 = fmul float %6476, 0x3FF7154760000000, !dbg !290 + %6789 = select i1 %6726, float %6788, float 0xFFF0000000000000, !dbg !289 + %6790 = fmul float %6477, 0x3FF7154760000000, !dbg !290 + %6791 = select i1 %6727, float %6790, float 0xFFF0000000000000, !dbg !289 + %6792 = fsub float %6729, %6043, !dbg !291 + %6793 = fsub float %6731, %6044, !dbg !291 + %6794 = fsub float %6733, %6043, !dbg !291 + %6795 = fsub float %6735, %6044, !dbg !291 + %6796 = fsub float %6737, %6045, !dbg !291 + %6797 = fsub float %6739, %6046, !dbg !291 + %6798 = fsub float %6741, %6045, !dbg !291 + %6799 = fsub float %6743, %6046, !dbg !291 + %6800 = fsub float %6745, %6047, !dbg !291 + %6801 = fsub float %6747, %6048, !dbg !291 + %6802 = fsub float %6749, %6047, !dbg !291 + %6803 = fsub float %6751, %6048, !dbg !291 + %6804 = fsub float %6753, %6049, !dbg !291 + %6805 = fsub float %6755, %6050, !dbg !291 + %6806 = fsub float %6757, %6049, !dbg !291 + %6807 = fsub float %6759, %6050, !dbg !291 + %6808 = fsub float %6761, %6051, !dbg !291 + %6809 = fsub float %6763, %6052, !dbg !291 + %6810 = fsub float %6765, %6051, !dbg !291 + %6811 = fsub float %6767, %6052, !dbg !291 + %6812 = fsub float %6769, %6053, !dbg !291 + %6813 = fsub float %6771, %6054, !dbg !291 + %6814 = fsub float %6773, %6053, !dbg !291 + %6815 = fsub float %6775, %6054, !dbg !291 + %6816 = fsub float %6777, %6055, !dbg !291 + %6817 = fsub float %6779, %6056, !dbg !291 + %6818 = fsub float %6781, %6055, !dbg !291 + %6819 = fsub float %6783, %6056, !dbg !291 + %6820 = fsub float %6785, %6057, !dbg !291 + %6821 = fsub float %6787, %6058, !dbg !291 + %6822 = fsub float %6789, %6057, !dbg !291 + %6823 = fsub float %6791, %6058, !dbg !291 + %6824 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1322 = icmp eq i32 %6824, 0, !dbg !292 + br i1 %.not.i1322, label %6827, label %6825, !dbg !292 + +6825: ; preds = %.lr.ph1700 + %6826 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6792) #3, !dbg !292 + br label %__nv_exp2f.exit1324, !dbg !292 + +6827: ; preds = %.lr.ph1700 + %6828 = tail call float @llvm.nvvm.ex2.approx.f(float %6792) #3, !dbg !292 + br label %__nv_exp2f.exit1324, !dbg !292 + +__nv_exp2f.exit1324: ; preds = %6825, %6827 + %.0.i1323 = phi float [ %6826, %6825 ], [ %6828, %6827 ], !dbg !292 + %6829 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1325 = icmp eq i32 %6829, 0, !dbg !292 + br i1 %.not.i1325, label %6832, label %6830, !dbg !292 + +6830: ; preds = %__nv_exp2f.exit1324 + %6831 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6793) #3, !dbg !292 + br label %__nv_exp2f.exit1327, !dbg !292 + +6832: ; preds = %__nv_exp2f.exit1324 + %6833 = tail call float @llvm.nvvm.ex2.approx.f(float %6793) #3, !dbg !292 + br label %__nv_exp2f.exit1327, !dbg !292 + +__nv_exp2f.exit1327: ; preds = %6830, %6832 + %.0.i1326 = phi float [ %6831, %6830 ], [ %6833, %6832 ], !dbg !292 + %6834 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1328 = icmp eq i32 %6834, 0, !dbg !292 + br i1 %.not.i1328, label %6837, label %6835, !dbg !292 + +6835: ; preds = %__nv_exp2f.exit1327 + %6836 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6794) #3, !dbg !292 + br label %__nv_exp2f.exit1330, !dbg !292 + +6837: ; preds = %__nv_exp2f.exit1327 + %6838 = tail call float @llvm.nvvm.ex2.approx.f(float %6794) #3, !dbg !292 + br label %__nv_exp2f.exit1330, !dbg !292 + +__nv_exp2f.exit1330: ; preds = %6835, %6837 + %.0.i1329 = phi float [ %6836, %6835 ], [ %6838, %6837 ], !dbg !292 + %6839 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1331 = icmp eq i32 %6839, 0, !dbg !292 + br i1 %.not.i1331, label %6842, label %6840, !dbg !292 + +6840: ; preds = %__nv_exp2f.exit1330 + %6841 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6795) #3, !dbg !292 + br label %__nv_exp2f.exit1333, !dbg !292 + +6842: ; preds = %__nv_exp2f.exit1330 + %6843 = tail call float @llvm.nvvm.ex2.approx.f(float %6795) #3, !dbg !292 + br label %__nv_exp2f.exit1333, !dbg !292 + +__nv_exp2f.exit1333: ; preds = %6840, %6842 + %.0.i1332 = phi float [ %6841, %6840 ], [ %6843, %6842 ], !dbg !292 + %6844 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1334 = icmp eq i32 %6844, 0, !dbg !292 + br i1 %.not.i1334, label %6847, label %6845, !dbg !292 + +6845: ; preds = %__nv_exp2f.exit1333 + %6846 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6796) #3, !dbg !292 + br label %__nv_exp2f.exit1336, !dbg !292 + +6847: ; preds = %__nv_exp2f.exit1333 + %6848 = tail call float @llvm.nvvm.ex2.approx.f(float %6796) #3, !dbg !292 + br label %__nv_exp2f.exit1336, !dbg !292 + +__nv_exp2f.exit1336: ; preds = %6845, %6847 + %.0.i1335 = phi float [ %6846, %6845 ], [ %6848, %6847 ], !dbg !292 + %6849 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1337 = icmp eq i32 %6849, 0, !dbg !292 + br i1 %.not.i1337, label %6852, label %6850, !dbg !292 + +6850: ; preds = %__nv_exp2f.exit1336 + %6851 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6797) #3, !dbg !292 + br label %__nv_exp2f.exit1339, !dbg !292 + +6852: ; preds = %__nv_exp2f.exit1336 + %6853 = tail call float @llvm.nvvm.ex2.approx.f(float %6797) #3, !dbg !292 + br label %__nv_exp2f.exit1339, !dbg !292 + +__nv_exp2f.exit1339: ; preds = %6850, %6852 + %.0.i1338 = phi float [ %6851, %6850 ], [ %6853, %6852 ], !dbg !292 + %6854 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1340 = icmp eq i32 %6854, 0, !dbg !292 + br i1 %.not.i1340, label %6857, label %6855, !dbg !292 + +6855: ; preds = %__nv_exp2f.exit1339 + %6856 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6798) #3, !dbg !292 + br label %__nv_exp2f.exit1342, !dbg !292 + +6857: ; preds = %__nv_exp2f.exit1339 + %6858 = tail call float @llvm.nvvm.ex2.approx.f(float %6798) #3, !dbg !292 + br label %__nv_exp2f.exit1342, !dbg !292 + +__nv_exp2f.exit1342: ; preds = %6855, %6857 + %.0.i1341 = phi float [ %6856, %6855 ], [ %6858, %6857 ], !dbg !292 + %6859 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1343 = icmp eq i32 %6859, 0, !dbg !292 + br i1 %.not.i1343, label %6862, label %6860, !dbg !292 + +6860: ; preds = %__nv_exp2f.exit1342 + %6861 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6799) #3, !dbg !292 + br label %__nv_exp2f.exit1345, !dbg !292 + +6862: ; preds = %__nv_exp2f.exit1342 + %6863 = tail call float @llvm.nvvm.ex2.approx.f(float %6799) #3, !dbg !292 + br label %__nv_exp2f.exit1345, !dbg !292 + +__nv_exp2f.exit1345: ; preds = %6860, %6862 + %.0.i1344 = phi float [ %6861, %6860 ], [ %6863, %6862 ], !dbg !292 + %6864 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1346 = icmp eq i32 %6864, 0, !dbg !292 + br i1 %.not.i1346, label %6867, label %6865, !dbg !292 + +6865: ; preds = %__nv_exp2f.exit1345 + %6866 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6800) #3, !dbg !292 + br label %__nv_exp2f.exit1348, !dbg !292 + +6867: ; preds = %__nv_exp2f.exit1345 + %6868 = tail call float @llvm.nvvm.ex2.approx.f(float %6800) #3, !dbg !292 + br label %__nv_exp2f.exit1348, !dbg !292 + +__nv_exp2f.exit1348: ; preds = %6865, %6867 + %.0.i1347 = phi float [ %6866, %6865 ], [ %6868, %6867 ], !dbg !292 + %6869 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1349 = icmp eq i32 %6869, 0, !dbg !292 + br i1 %.not.i1349, label %6872, label %6870, !dbg !292 + +6870: ; preds = %__nv_exp2f.exit1348 + %6871 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6801) #3, !dbg !292 + br label %__nv_exp2f.exit1351, !dbg !292 + +6872: ; preds = %__nv_exp2f.exit1348 + %6873 = tail call float @llvm.nvvm.ex2.approx.f(float %6801) #3, !dbg !292 + br label %__nv_exp2f.exit1351, !dbg !292 + +__nv_exp2f.exit1351: ; preds = %6870, %6872 + %.0.i1350 = phi float [ %6871, %6870 ], [ %6873, %6872 ], !dbg !292 + %6874 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1352 = icmp eq i32 %6874, 0, !dbg !292 + br i1 %.not.i1352, label %6877, label %6875, !dbg !292 + +6875: ; preds = %__nv_exp2f.exit1351 + %6876 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6802) #3, !dbg !292 + br label %__nv_exp2f.exit1354, !dbg !292 + +6877: ; preds = %__nv_exp2f.exit1351 + %6878 = tail call float @llvm.nvvm.ex2.approx.f(float %6802) #3, !dbg !292 + br label %__nv_exp2f.exit1354, !dbg !292 + +__nv_exp2f.exit1354: ; preds = %6875, %6877 + %.0.i1353 = phi float [ %6876, %6875 ], [ %6878, %6877 ], !dbg !292 + %6879 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1355 = icmp eq i32 %6879, 0, !dbg !292 + br i1 %.not.i1355, label %6882, label %6880, !dbg !292 + +6880: ; preds = %__nv_exp2f.exit1354 + %6881 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6803) #3, !dbg !292 + br label %__nv_exp2f.exit1357, !dbg !292 + +6882: ; preds = %__nv_exp2f.exit1354 + %6883 = tail call float @llvm.nvvm.ex2.approx.f(float %6803) #3, !dbg !292 + br label %__nv_exp2f.exit1357, !dbg !292 + +__nv_exp2f.exit1357: ; preds = %6880, %6882 + %.0.i1356 = phi float [ %6881, %6880 ], [ %6883, %6882 ], !dbg !292 + %6884 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1358 = icmp eq i32 %6884, 0, !dbg !292 + br i1 %.not.i1358, label %6887, label %6885, !dbg !292 + +6885: ; preds = %__nv_exp2f.exit1357 + %6886 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6804) #3, !dbg !292 + br label %__nv_exp2f.exit1360, !dbg !292 + +6887: ; preds = %__nv_exp2f.exit1357 + %6888 = tail call float @llvm.nvvm.ex2.approx.f(float %6804) #3, !dbg !292 + br label %__nv_exp2f.exit1360, !dbg !292 + +__nv_exp2f.exit1360: ; preds = %6885, %6887 + %.0.i1359 = phi float [ %6886, %6885 ], [ %6888, %6887 ], !dbg !292 + %6889 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1361 = icmp eq i32 %6889, 0, !dbg !292 + br i1 %.not.i1361, label %6892, label %6890, !dbg !292 + +6890: ; preds = %__nv_exp2f.exit1360 + %6891 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6805) #3, !dbg !292 + br label %__nv_exp2f.exit1363, !dbg !292 + +6892: ; preds = %__nv_exp2f.exit1360 + %6893 = tail call float @llvm.nvvm.ex2.approx.f(float %6805) #3, !dbg !292 + br label %__nv_exp2f.exit1363, !dbg !292 + +__nv_exp2f.exit1363: ; preds = %6890, %6892 + %.0.i1362 = phi float [ %6891, %6890 ], [ %6893, %6892 ], !dbg !292 + %6894 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1364 = icmp eq i32 %6894, 0, !dbg !292 + br i1 %.not.i1364, label %6897, label %6895, !dbg !292 + +6895: ; preds = %__nv_exp2f.exit1363 + %6896 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6806) #3, !dbg !292 + br label %__nv_exp2f.exit1366, !dbg !292 + +6897: ; preds = %__nv_exp2f.exit1363 + %6898 = tail call float @llvm.nvvm.ex2.approx.f(float %6806) #3, !dbg !292 + br label %__nv_exp2f.exit1366, !dbg !292 + +__nv_exp2f.exit1366: ; preds = %6895, %6897 + %.0.i1365 = phi float [ %6896, %6895 ], [ %6898, %6897 ], !dbg !292 + %6899 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1367 = icmp eq i32 %6899, 0, !dbg !292 + br i1 %.not.i1367, label %6902, label %6900, !dbg !292 + +6900: ; preds = %__nv_exp2f.exit1366 + %6901 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6807) #3, !dbg !292 + br label %__nv_exp2f.exit1369, !dbg !292 + +6902: ; preds = %__nv_exp2f.exit1366 + %6903 = tail call float @llvm.nvvm.ex2.approx.f(float %6807) #3, !dbg !292 + br label %__nv_exp2f.exit1369, !dbg !292 + +__nv_exp2f.exit1369: ; preds = %6900, %6902 + %.0.i1368 = phi float [ %6901, %6900 ], [ %6903, %6902 ], !dbg !292 + %6904 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1370 = icmp eq i32 %6904, 0, !dbg !292 + br i1 %.not.i1370, label %6907, label %6905, !dbg !292 + +6905: ; preds = %__nv_exp2f.exit1369 + %6906 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6808) #3, !dbg !292 + br label %__nv_exp2f.exit1372, !dbg !292 + +6907: ; preds = %__nv_exp2f.exit1369 + %6908 = tail call float @llvm.nvvm.ex2.approx.f(float %6808) #3, !dbg !292 + br label %__nv_exp2f.exit1372, !dbg !292 + +__nv_exp2f.exit1372: ; preds = %6905, %6907 + %.0.i1371 = phi float [ %6906, %6905 ], [ %6908, %6907 ], !dbg !292 + %6909 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1373 = icmp eq i32 %6909, 0, !dbg !292 + br i1 %.not.i1373, label %6912, label %6910, !dbg !292 + +6910: ; preds = %__nv_exp2f.exit1372 + %6911 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6809) #3, !dbg !292 + br label %__nv_exp2f.exit1375, !dbg !292 + +6912: ; preds = %__nv_exp2f.exit1372 + %6913 = tail call float @llvm.nvvm.ex2.approx.f(float %6809) #3, !dbg !292 + br label %__nv_exp2f.exit1375, !dbg !292 + +__nv_exp2f.exit1375: ; preds = %6910, %6912 + %.0.i1374 = phi float [ %6911, %6910 ], [ %6913, %6912 ], !dbg !292 + %6914 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1376 = icmp eq i32 %6914, 0, !dbg !292 + br i1 %.not.i1376, label %6917, label %6915, !dbg !292 + +6915: ; preds = %__nv_exp2f.exit1375 + %6916 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6810) #3, !dbg !292 + br label %__nv_exp2f.exit1378, !dbg !292 + +6917: ; preds = %__nv_exp2f.exit1375 + %6918 = tail call float @llvm.nvvm.ex2.approx.f(float %6810) #3, !dbg !292 + br label %__nv_exp2f.exit1378, !dbg !292 + +__nv_exp2f.exit1378: ; preds = %6915, %6917 + %.0.i1377 = phi float [ %6916, %6915 ], [ %6918, %6917 ], !dbg !292 + %6919 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1379 = icmp eq i32 %6919, 0, !dbg !292 + br i1 %.not.i1379, label %6922, label %6920, !dbg !292 + +6920: ; preds = %__nv_exp2f.exit1378 + %6921 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6811) #3, !dbg !292 + br label %__nv_exp2f.exit1381, !dbg !292 + +6922: ; preds = %__nv_exp2f.exit1378 + %6923 = tail call float @llvm.nvvm.ex2.approx.f(float %6811) #3, !dbg !292 + br label %__nv_exp2f.exit1381, !dbg !292 + +__nv_exp2f.exit1381: ; preds = %6920, %6922 + %.0.i1380 = phi float [ %6921, %6920 ], [ %6923, %6922 ], !dbg !292 + %6924 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1382 = icmp eq i32 %6924, 0, !dbg !292 + br i1 %.not.i1382, label %6927, label %6925, !dbg !292 + +6925: ; preds = %__nv_exp2f.exit1381 + %6926 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6812) #3, !dbg !292 + br label %__nv_exp2f.exit1384, !dbg !292 + +6927: ; preds = %__nv_exp2f.exit1381 + %6928 = tail call float @llvm.nvvm.ex2.approx.f(float %6812) #3, !dbg !292 + br label %__nv_exp2f.exit1384, !dbg !292 + +__nv_exp2f.exit1384: ; preds = %6925, %6927 + %.0.i1383 = phi float [ %6926, %6925 ], [ %6928, %6927 ], !dbg !292 + %6929 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1385 = icmp eq i32 %6929, 0, !dbg !292 + br i1 %.not.i1385, label %6932, label %6930, !dbg !292 + +6930: ; preds = %__nv_exp2f.exit1384 + %6931 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6813) #3, !dbg !292 + br label %__nv_exp2f.exit1387, !dbg !292 + +6932: ; preds = %__nv_exp2f.exit1384 + %6933 = tail call float @llvm.nvvm.ex2.approx.f(float %6813) #3, !dbg !292 + br label %__nv_exp2f.exit1387, !dbg !292 + +__nv_exp2f.exit1387: ; preds = %6930, %6932 + %.0.i1386 = phi float [ %6931, %6930 ], [ %6933, %6932 ], !dbg !292 + %6934 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1388 = icmp eq i32 %6934, 0, !dbg !292 + br i1 %.not.i1388, label %6937, label %6935, !dbg !292 + +6935: ; preds = %__nv_exp2f.exit1387 + %6936 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6814) #3, !dbg !292 + br label %__nv_exp2f.exit1390, !dbg !292 + +6937: ; preds = %__nv_exp2f.exit1387 + %6938 = tail call float @llvm.nvvm.ex2.approx.f(float %6814) #3, !dbg !292 + br label %__nv_exp2f.exit1390, !dbg !292 + +__nv_exp2f.exit1390: ; preds = %6935, %6937 + %.0.i1389 = phi float [ %6936, %6935 ], [ %6938, %6937 ], !dbg !292 + %6939 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1391 = icmp eq i32 %6939, 0, !dbg !292 + br i1 %.not.i1391, label %6942, label %6940, !dbg !292 + +6940: ; preds = %__nv_exp2f.exit1390 + %6941 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6815) #3, !dbg !292 + br label %__nv_exp2f.exit1393, !dbg !292 + +6942: ; preds = %__nv_exp2f.exit1390 + %6943 = tail call float @llvm.nvvm.ex2.approx.f(float %6815) #3, !dbg !292 + br label %__nv_exp2f.exit1393, !dbg !292 + +__nv_exp2f.exit1393: ; preds = %6940, %6942 + %.0.i1392 = phi float [ %6941, %6940 ], [ %6943, %6942 ], !dbg !292 + %6944 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1394 = icmp eq i32 %6944, 0, !dbg !292 + br i1 %.not.i1394, label %6947, label %6945, !dbg !292 + +6945: ; preds = %__nv_exp2f.exit1393 + %6946 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6816) #3, !dbg !292 + br label %__nv_exp2f.exit1396, !dbg !292 + +6947: ; preds = %__nv_exp2f.exit1393 + %6948 = tail call float @llvm.nvvm.ex2.approx.f(float %6816) #3, !dbg !292 + br label %__nv_exp2f.exit1396, !dbg !292 + +__nv_exp2f.exit1396: ; preds = %6945, %6947 + %.0.i1395 = phi float [ %6946, %6945 ], [ %6948, %6947 ], !dbg !292 + %6949 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1397 = icmp eq i32 %6949, 0, !dbg !292 + br i1 %.not.i1397, label %6952, label %6950, !dbg !292 + +6950: ; preds = %__nv_exp2f.exit1396 + %6951 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6817) #3, !dbg !292 + br label %__nv_exp2f.exit1399, !dbg !292 + +6952: ; preds = %__nv_exp2f.exit1396 + %6953 = tail call float @llvm.nvvm.ex2.approx.f(float %6817) #3, !dbg !292 + br label %__nv_exp2f.exit1399, !dbg !292 + +__nv_exp2f.exit1399: ; preds = %6950, %6952 + %.0.i1398 = phi float [ %6951, %6950 ], [ %6953, %6952 ], !dbg !292 + %6954 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1400 = icmp eq i32 %6954, 0, !dbg !292 + br i1 %.not.i1400, label %6957, label %6955, !dbg !292 + +6955: ; preds = %__nv_exp2f.exit1399 + %6956 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6818) #3, !dbg !292 + br label %__nv_exp2f.exit1402, !dbg !292 + +6957: ; preds = %__nv_exp2f.exit1399 + %6958 = tail call float @llvm.nvvm.ex2.approx.f(float %6818) #3, !dbg !292 + br label %__nv_exp2f.exit1402, !dbg !292 + +__nv_exp2f.exit1402: ; preds = %6955, %6957 + %.0.i1401 = phi float [ %6956, %6955 ], [ %6958, %6957 ], !dbg !292 + %6959 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1403 = icmp eq i32 %6959, 0, !dbg !292 + br i1 %.not.i1403, label %6962, label %6960, !dbg !292 + +6960: ; preds = %__nv_exp2f.exit1402 + %6961 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6819) #3, !dbg !292 + br label %__nv_exp2f.exit1405, !dbg !292 + +6962: ; preds = %__nv_exp2f.exit1402 + %6963 = tail call float @llvm.nvvm.ex2.approx.f(float %6819) #3, !dbg !292 + br label %__nv_exp2f.exit1405, !dbg !292 + +__nv_exp2f.exit1405: ; preds = %6960, %6962 + %.0.i1404 = phi float [ %6961, %6960 ], [ %6963, %6962 ], !dbg !292 + %6964 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1406 = icmp eq i32 %6964, 0, !dbg !292 + br i1 %.not.i1406, label %6967, label %6965, !dbg !292 + +6965: ; preds = %__nv_exp2f.exit1405 + %6966 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6820) #3, !dbg !292 + br label %__nv_exp2f.exit1408, !dbg !292 + +6967: ; preds = %__nv_exp2f.exit1405 + %6968 = tail call float @llvm.nvvm.ex2.approx.f(float %6820) #3, !dbg !292 + br label %__nv_exp2f.exit1408, !dbg !292 + +__nv_exp2f.exit1408: ; preds = %6965, %6967 + %.0.i1407 = phi float [ %6966, %6965 ], [ %6968, %6967 ], !dbg !292 + %6969 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1409 = icmp eq i32 %6969, 0, !dbg !292 + br i1 %.not.i1409, label %6972, label %6970, !dbg !292 + +6970: ; preds = %__nv_exp2f.exit1408 + %6971 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6821) #3, !dbg !292 + br label %__nv_exp2f.exit1411, !dbg !292 + +6972: ; preds = %__nv_exp2f.exit1408 + %6973 = tail call float @llvm.nvvm.ex2.approx.f(float %6821) #3, !dbg !292 + br label %__nv_exp2f.exit1411, !dbg !292 + +__nv_exp2f.exit1411: ; preds = %6970, %6972 + %.0.i1410 = phi float [ %6971, %6970 ], [ %6973, %6972 ], !dbg !292 + %6974 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1412 = icmp eq i32 %6974, 0, !dbg !292 + br i1 %.not.i1412, label %6977, label %6975, !dbg !292 + +6975: ; preds = %__nv_exp2f.exit1411 + %6976 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6822) #3, !dbg !292 + br label %__nv_exp2f.exit1414, !dbg !292 + +6977: ; preds = %__nv_exp2f.exit1411 + %6978 = tail call float @llvm.nvvm.ex2.approx.f(float %6822) #3, !dbg !292 + br label %__nv_exp2f.exit1414, !dbg !292 + +__nv_exp2f.exit1414: ; preds = %6975, %6977 + %.0.i1413 = phi float [ %6976, %6975 ], [ %6978, %6977 ], !dbg !292 + %6979 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !292 + %.not.i1415 = icmp eq i32 %6979, 0, !dbg !292 + br i1 %.not.i1415, label %6982, label %6980, !dbg !292 + +6980: ; preds = %__nv_exp2f.exit1414 + %6981 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6823) #3, !dbg !292 + br label %__nv_exp2f.exit1417, !dbg !292 + +6982: ; preds = %__nv_exp2f.exit1414 + %6983 = tail call float @llvm.nvvm.ex2.approx.f(float %6823) #3, !dbg !292 + br label %__nv_exp2f.exit1417, !dbg !292 + +__nv_exp2f.exit1417: ; preds = %6980, %6982 + %.0.i1416 = phi float [ %6981, %6980 ], [ %6983, %6982 ], !dbg !292 + %6984 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %5991, !dbg !269 + %6985 = insertelement <2 x float> poison, float %.0.i1323, i64 0, !dbg !293 + %6986 = insertelement <2 x float> %6985, float %.0.i1326, i64 1, !dbg !293 + %6987 = fptrunc <2 x float> %6986 to <2 x bfloat>, !dbg !293 + %6988 = insertelement <2 x float> poison, float %.0.i1329, i64 0, !dbg !293 + %6989 = insertelement <2 x float> %6988, float %.0.i1332, i64 1, !dbg !293 + %6990 = fptrunc <2 x float> %6989 to <2 x bfloat>, !dbg !293 + %6991 = insertelement <2 x float> poison, float %.0.i1335, i64 0, !dbg !293 + %6992 = insertelement <2 x float> %6991, float %.0.i1338, i64 1, !dbg !293 + %6993 = fptrunc <2 x float> %6992 to <2 x bfloat>, !dbg !293 + %6994 = insertelement <2 x float> poison, float %.0.i1341, i64 0, !dbg !293 + %6995 = insertelement <2 x float> %6994, float %.0.i1344, i64 1, !dbg !293 + %6996 = fptrunc <2 x float> %6995 to <2 x bfloat>, !dbg !293 + %6997 = insertelement <2 x float> poison, float %.0.i1347, i64 0, !dbg !293 + %6998 = insertelement <2 x float> %6997, float %.0.i1350, i64 1, !dbg !293 + %6999 = fptrunc <2 x float> %6998 to <2 x bfloat>, !dbg !293 + %7000 = insertelement <2 x float> poison, float %.0.i1353, i64 0, !dbg !293 + %7001 = insertelement <2 x float> %7000, float %.0.i1356, i64 1, !dbg !293 + %7002 = fptrunc <2 x float> %7001 to <2 x bfloat>, !dbg !293 + %7003 = insertelement <2 x float> poison, float %.0.i1359, i64 0, !dbg !293 + %7004 = insertelement <2 x float> %7003, float %.0.i1362, i64 1, !dbg !293 + %7005 = fptrunc <2 x float> %7004 to <2 x bfloat>, !dbg !293 + %7006 = insertelement <2 x float> poison, float %.0.i1365, i64 0, !dbg !293 + %7007 = insertelement <2 x float> %7006, float %.0.i1368, i64 1, !dbg !293 + %7008 = fptrunc <2 x float> %7007 to <2 x bfloat>, !dbg !293 + %7009 = insertelement <2 x float> poison, float %.0.i1371, i64 0, !dbg !293 + %7010 = insertelement <2 x float> %7009, float %.0.i1374, i64 1, !dbg !293 + %7011 = fptrunc <2 x float> %7010 to <2 x bfloat>, !dbg !293 + %7012 = insertelement <2 x float> poison, float %.0.i1377, i64 0, !dbg !293 + %7013 = insertelement <2 x float> %7012, float %.0.i1380, i64 1, !dbg !293 + %7014 = fptrunc <2 x float> %7013 to <2 x bfloat>, !dbg !293 + %7015 = insertelement <2 x float> poison, float %.0.i1383, i64 0, !dbg !293 + %7016 = insertelement <2 x float> %7015, float %.0.i1386, i64 1, !dbg !293 + %7017 = fptrunc <2 x float> %7016 to <2 x bfloat>, !dbg !293 + %7018 = insertelement <2 x float> poison, float %.0.i1389, i64 0, !dbg !293 + %7019 = insertelement <2 x float> %7018, float %.0.i1392, i64 1, !dbg !293 + %7020 = fptrunc <2 x float> %7019 to <2 x bfloat>, !dbg !293 + %7021 = insertelement <2 x float> poison, float %.0.i1395, i64 0, !dbg !293 + %7022 = insertelement <2 x float> %7021, float %.0.i1398, i64 1, !dbg !293 + %7023 = fptrunc <2 x float> %7022 to <2 x bfloat>, !dbg !293 + %7024 = insertelement <2 x float> poison, float %.0.i1401, i64 0, !dbg !293 + %7025 = insertelement <2 x float> %7024, float %.0.i1404, i64 1, !dbg !293 + %7026 = fptrunc <2 x float> %7025 to <2 x bfloat>, !dbg !293 + %7027 = insertelement <2 x float> poison, float %.0.i1407, i64 0, !dbg !293 + %7028 = insertelement <2 x float> %7027, float %.0.i1410, i64 1, !dbg !293 + %7029 = fptrunc <2 x float> %7028 to <2 x bfloat>, !dbg !293 + %7030 = insertelement <2 x float> poison, float %.0.i1413, i64 0, !dbg !293 + %7031 = insertelement <2 x float> %7030, float %.0.i1416, i64 1, !dbg !293 + %7032 = fptrunc <2 x float> %7031 to <2 x bfloat>, !dbg !293 + %7033 = bitcast <2 x bfloat> %6987 to i32, !dbg !294 + %7034 = bitcast <2 x bfloat> %6990 to i32, !dbg !294 + %7035 = bitcast <2 x bfloat> %6993 to i32, !dbg !294 + %7036 = bitcast <2 x bfloat> %6996 to i32, !dbg !294 + %7037 = bitcast <2 x bfloat> %6999 to i32, !dbg !294 + %7038 = bitcast <2 x bfloat> %7002 to i32, !dbg !294 + %7039 = bitcast <2 x bfloat> %7005 to i32, !dbg !294 + %7040 = bitcast <2 x bfloat> %7008 to i32, !dbg !294 + %7041 = bitcast <2 x bfloat> %7011 to i32, !dbg !294 + %7042 = bitcast <2 x bfloat> %7014 to i32, !dbg !294 + %7043 = bitcast <2 x bfloat> %7017 to i32, !dbg !294 + %7044 = bitcast <2 x bfloat> %7020 to i32, !dbg !294 + %7045 = bitcast <2 x bfloat> %7023 to i32, !dbg !294 + %7046 = bitcast <2 x bfloat> %7026 to i32, !dbg !294 + %7047 = bitcast <2 x bfloat> %7029 to i32, !dbg !294 + %7048 = bitcast <2 x bfloat> %7032 to i32, !dbg !294 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !294 + %7049 = ptrtoint ptr addrspace(3) %6984 to i32, !dbg !294 + %7050 = lshr exact i32 %7049, 4, !dbg !294 + %7051 = and i32 %7050, 16383, !dbg !294 + %7052 = zext nneg i32 %7051 to i64, !dbg !294 + %7053 = or disjoint i64 %7052, 4611686293338849280, !dbg !294 + %7054 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5837, float %5838, float %5839, float %5840, float %5841, float %5842, float %5843, float %5844, float %5845, float %5846, float %5847, float %5848, float %5849, float %5850, float %5851, float %5852, float %5853, float %5854, float %5855, float %5856, float %5857, float %5858, float %5859, float %5860, float %5861, float %5862, float %5863, float %5864, float %5865, float %5866, float %5867, float %5868, float %5869, float %5870, float %5871, float %5872, float %5873, float %5874, float %5875, float %5876, float %5877, float %5878, float %5879, float %5880, float %5881, float %5882, float %5883, float %5884, float %5885, float %5886, float %5887, float %5888, float %5889, float %5890, float %5891, float %5892, float %5893, float %5894, float %5895, float %5896, float %5897, float %5898, float %5899, float %5900, i32 %7033, i32 %7034, i32 %7035, i32 %7036, i64 %7053, i1 true) #3, !dbg !294 + %7055 = add i32 %7049, 2048, !dbg !294 + %7056 = lshr exact i32 %7055, 4, !dbg !294 + %7057 = and i32 %7056, 16383, !dbg !294 + %7058 = zext nneg i32 %7057 to i64, !dbg !294 + %7059 = or disjoint i64 %7058, 4611686293338849280, !dbg !294 + %7060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 0, !dbg !294 + %7061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 1, !dbg !294 + %7062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 2, !dbg !294 + %7063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 3, !dbg !294 + %7064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 4, !dbg !294 + %7065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 5, !dbg !294 + %7066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 6, !dbg !294 + %7067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 7, !dbg !294 + %7068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 8, !dbg !294 + %7069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 9, !dbg !294 + %7070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 10, !dbg !294 + %7071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 11, !dbg !294 + %7072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 12, !dbg !294 + %7073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 13, !dbg !294 + %7074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 14, !dbg !294 + %7075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 15, !dbg !294 + %7076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 16, !dbg !294 + %7077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 17, !dbg !294 + %7078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 18, !dbg !294 + %7079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 19, !dbg !294 + %7080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 20, !dbg !294 + %7081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 21, !dbg !294 + %7082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 22, !dbg !294 + %7083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 23, !dbg !294 + %7084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 24, !dbg !294 + %7085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 25, !dbg !294 + %7086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 26, !dbg !294 + %7087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 27, !dbg !294 + %7088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 28, !dbg !294 + %7089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 29, !dbg !294 + %7090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 30, !dbg !294 + %7091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 31, !dbg !294 + %7092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 32, !dbg !294 + %7093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 33, !dbg !294 + %7094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 34, !dbg !294 + %7095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 35, !dbg !294 + %7096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 36, !dbg !294 + %7097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 37, !dbg !294 + %7098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 38, !dbg !294 + %7099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 39, !dbg !294 + %7100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 40, !dbg !294 + %7101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 41, !dbg !294 + %7102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 42, !dbg !294 + %7103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 43, !dbg !294 + %7104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 44, !dbg !294 + %7105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 45, !dbg !294 + %7106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 46, !dbg !294 + %7107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 47, !dbg !294 + %7108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 48, !dbg !294 + %7109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 49, !dbg !294 + %7110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 50, !dbg !294 + %7111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 51, !dbg !294 + %7112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 52, !dbg !294 + %7113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 53, !dbg !294 + %7114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 54, !dbg !294 + %7115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 55, !dbg !294 + %7116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 56, !dbg !294 + %7117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 57, !dbg !294 + %7118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 58, !dbg !294 + %7119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 59, !dbg !294 + %7120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 60, !dbg !294 + %7121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 61, !dbg !294 + %7122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 62, !dbg !294 + %7123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7054, 63, !dbg !294 + %7124 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7060, float %7061, float %7062, float %7063, float %7064, float %7065, float %7066, float %7067, float %7068, float %7069, float %7070, float %7071, float %7072, float %7073, float %7074, float %7075, float %7076, float %7077, float %7078, float %7079, float %7080, float %7081, float %7082, float %7083, float %7084, float %7085, float %7086, float %7087, float %7088, float %7089, float %7090, float %7091, float %7092, float %7093, float %7094, float %7095, float %7096, float %7097, float %7098, float %7099, float %7100, float %7101, float %7102, float %7103, float %7104, float %7105, float %7106, float %7107, float %7108, float %7109, float %7110, float %7111, float %7112, float %7113, float %7114, float %7115, float %7116, float %7117, float %7118, float %7119, float %7120, float %7121, float %7122, float %7123, i32 %7037, i32 %7038, i32 %7039, i32 %7040, i64 %7059, i1 true) #3, !dbg !294 + %7125 = add i32 %7049, 4096, !dbg !294 + %7126 = lshr exact i32 %7125, 4, !dbg !294 + %7127 = and i32 %7126, 16383, !dbg !294 + %7128 = zext nneg i32 %7127 to i64, !dbg !294 + %7129 = or disjoint i64 %7128, 4611686293338849280, !dbg !294 + %7130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 0, !dbg !294 + %7131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 1, !dbg !294 + %7132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 2, !dbg !294 + %7133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 3, !dbg !294 + %7134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 4, !dbg !294 + %7135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 5, !dbg !294 + %7136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 6, !dbg !294 + %7137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 7, !dbg !294 + %7138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 8, !dbg !294 + %7139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 9, !dbg !294 + %7140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 10, !dbg !294 + %7141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 11, !dbg !294 + %7142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 12, !dbg !294 + %7143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 13, !dbg !294 + %7144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 14, !dbg !294 + %7145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 15, !dbg !294 + %7146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 16, !dbg !294 + %7147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 17, !dbg !294 + %7148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 18, !dbg !294 + %7149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 19, !dbg !294 + %7150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 20, !dbg !294 + %7151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 21, !dbg !294 + %7152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 22, !dbg !294 + %7153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 23, !dbg !294 + %7154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 24, !dbg !294 + %7155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 25, !dbg !294 + %7156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 26, !dbg !294 + %7157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 27, !dbg !294 + %7158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 28, !dbg !294 + %7159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 29, !dbg !294 + %7160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 30, !dbg !294 + %7161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 31, !dbg !294 + %7162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 32, !dbg !294 + %7163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 33, !dbg !294 + %7164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 34, !dbg !294 + %7165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 35, !dbg !294 + %7166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 36, !dbg !294 + %7167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 37, !dbg !294 + %7168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 38, !dbg !294 + %7169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 39, !dbg !294 + %7170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 40, !dbg !294 + %7171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 41, !dbg !294 + %7172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 42, !dbg !294 + %7173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 43, !dbg !294 + %7174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 44, !dbg !294 + %7175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 45, !dbg !294 + %7176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 46, !dbg !294 + %7177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 47, !dbg !294 + %7178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 48, !dbg !294 + %7179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 49, !dbg !294 + %7180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 50, !dbg !294 + %7181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 51, !dbg !294 + %7182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 52, !dbg !294 + %7183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 53, !dbg !294 + %7184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 54, !dbg !294 + %7185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 55, !dbg !294 + %7186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 56, !dbg !294 + %7187 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 57, !dbg !294 + %7188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 58, !dbg !294 + %7189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 59, !dbg !294 + %7190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 60, !dbg !294 + %7191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 61, !dbg !294 + %7192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 62, !dbg !294 + %7193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7124, 63, !dbg !294 + %7194 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7130, float %7131, float %7132, float %7133, float %7134, float %7135, float %7136, float %7137, float %7138, float %7139, float %7140, float %7141, float %7142, float %7143, float %7144, float %7145, float %7146, float %7147, float %7148, float %7149, float %7150, float %7151, float %7152, float %7153, float %7154, float %7155, float %7156, float %7157, float %7158, float %7159, float %7160, float %7161, float %7162, float %7163, float %7164, float %7165, float %7166, float %7167, float %7168, float %7169, float %7170, float %7171, float %7172, float %7173, float %7174, float %7175, float %7176, float %7177, float %7178, float %7179, float %7180, float %7181, float %7182, float %7183, float %7184, float %7185, float %7186, float %7187, float %7188, float %7189, float %7190, float %7191, float %7192, float %7193, i32 %7041, i32 %7042, i32 %7043, i32 %7044, i64 %7129, i1 true) #3, !dbg !294 + %7195 = add i32 %7049, 6144, !dbg !294 + %7196 = lshr exact i32 %7195, 4, !dbg !294 + %7197 = and i32 %7196, 16383, !dbg !294 + %7198 = zext nneg i32 %7197 to i64, !dbg !294 + %7199 = or disjoint i64 %7198, 4611686293338849280, !dbg !294 + %7200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 0, !dbg !294 + %7201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 1, !dbg !294 + %7202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 2, !dbg !294 + %7203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 3, !dbg !294 + %7204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 4, !dbg !294 + %7205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 5, !dbg !294 + %7206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 6, !dbg !294 + %7207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 7, !dbg !294 + %7208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 8, !dbg !294 + %7209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 9, !dbg !294 + %7210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 10, !dbg !294 + %7211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 11, !dbg !294 + %7212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 12, !dbg !294 + %7213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 13, !dbg !294 + %7214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 14, !dbg !294 + %7215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 15, !dbg !294 + %7216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 16, !dbg !294 + %7217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 17, !dbg !294 + %7218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 18, !dbg !294 + %7219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 19, !dbg !294 + %7220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 20, !dbg !294 + %7221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 21, !dbg !294 + %7222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 22, !dbg !294 + %7223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 23, !dbg !294 + %7224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 24, !dbg !294 + %7225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 25, !dbg !294 + %7226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 26, !dbg !294 + %7227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 27, !dbg !294 + %7228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 28, !dbg !294 + %7229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 29, !dbg !294 + %7230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 30, !dbg !294 + %7231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 31, !dbg !294 + %7232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 32, !dbg !294 + %7233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 33, !dbg !294 + %7234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 34, !dbg !294 + %7235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 35, !dbg !294 + %7236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 36, !dbg !294 + %7237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 37, !dbg !294 + %7238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 38, !dbg !294 + %7239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 39, !dbg !294 + %7240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 40, !dbg !294 + %7241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 41, !dbg !294 + %7242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 42, !dbg !294 + %7243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 43, !dbg !294 + %7244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 44, !dbg !294 + %7245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 45, !dbg !294 + %7246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 46, !dbg !294 + %7247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 47, !dbg !294 + %7248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 48, !dbg !294 + %7249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 49, !dbg !294 + %7250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 50, !dbg !294 + %7251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 51, !dbg !294 + %7252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 52, !dbg !294 + %7253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 53, !dbg !294 + %7254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 54, !dbg !294 + %7255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 55, !dbg !294 + %7256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 56, !dbg !294 + %7257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 57, !dbg !294 + %7258 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 58, !dbg !294 + %7259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 59, !dbg !294 + %7260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 60, !dbg !294 + %7261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 61, !dbg !294 + %7262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 62, !dbg !294 + %7263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7194, 63, !dbg !294 + %7264 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7200, float %7201, float %7202, float %7203, float %7204, float %7205, float %7206, float %7207, float %7208, float %7209, float %7210, float %7211, float %7212, float %7213, float %7214, float %7215, float %7216, float %7217, float %7218, float %7219, float %7220, float %7221, float %7222, float %7223, float %7224, float %7225, float %7226, float %7227, float %7228, float %7229, float %7230, float %7231, float %7232, float %7233, float %7234, float %7235, float %7236, float %7237, float %7238, float %7239, float %7240, float %7241, float %7242, float %7243, float %7244, float %7245, float %7246, float %7247, float %7248, float %7249, float %7250, float %7251, float %7252, float %7253, float %7254, float %7255, float %7256, float %7257, float %7258, float %7259, float %7260, float %7261, float %7262, float %7263, i32 %7045, i32 %7046, i32 %7047, i32 %7048, i64 %7199, i1 true) #3, !dbg !294 + %7265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 0, !dbg !294 + %7266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 1, !dbg !294 + %7267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 2, !dbg !294 + %7268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 3, !dbg !294 + %7269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 4, !dbg !294 + %7270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 5, !dbg !294 + %7271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 6, !dbg !294 + %7272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 7, !dbg !294 + %7273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 8, !dbg !294 + %7274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 9, !dbg !294 + %7275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 10, !dbg !294 + %7276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 11, !dbg !294 + %7277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 12, !dbg !294 + %7278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 13, !dbg !294 + %7279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 14, !dbg !294 + %7280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 15, !dbg !294 + %7281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 16, !dbg !294 + %7282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 17, !dbg !294 + %7283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 18, !dbg !294 + %7284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 19, !dbg !294 + %7285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 20, !dbg !294 + %7286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 21, !dbg !294 + %7287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 22, !dbg !294 + %7288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 23, !dbg !294 + %7289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 24, !dbg !294 + %7290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 25, !dbg !294 + %7291 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 26, !dbg !294 + %7292 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 27, !dbg !294 + %7293 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 28, !dbg !294 + %7294 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 29, !dbg !294 + %7295 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 30, !dbg !294 + %7296 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 31, !dbg !294 + %7297 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 32, !dbg !294 + %7298 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 33, !dbg !294 + %7299 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 34, !dbg !294 + %7300 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 35, !dbg !294 + %7301 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 36, !dbg !294 + %7302 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 37, !dbg !294 + %7303 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 38, !dbg !294 + %7304 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 39, !dbg !294 + %7305 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 40, !dbg !294 + %7306 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 41, !dbg !294 + %7307 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 42, !dbg !294 + %7308 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 43, !dbg !294 + %7309 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 44, !dbg !294 + %7310 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 45, !dbg !294 + %7311 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 46, !dbg !294 + %7312 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 47, !dbg !294 + %7313 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 48, !dbg !294 + %7314 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 49, !dbg !294 + %7315 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 50, !dbg !294 + %7316 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 51, !dbg !294 + %7317 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 52, !dbg !294 + %7318 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 53, !dbg !294 + %7319 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 54, !dbg !294 + %7320 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 55, !dbg !294 + %7321 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 56, !dbg !294 + %7322 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 57, !dbg !294 + %7323 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 58, !dbg !294 + %7324 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 59, !dbg !294 + %7325 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 60, !dbg !294 + %7326 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 61, !dbg !294 + %7327 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 62, !dbg !294 + %7328 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7264, 63, !dbg !294 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !294 + %7329 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %5993, !dbg !271 + %7330 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5179, !dbg !271 + %7331 = load float, ptr addrspace(3) %7330, align 8, !dbg !271 + %7332 = getelementptr inbounds nuw i8, ptr addrspace(3) %7330, i32 4, !dbg !271 + %7333 = load float, ptr addrspace(3) %7332, align 4, !dbg !271 + %7334 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5185, !dbg !271 + %7335 = load float, ptr addrspace(3) %7334, align 8, !dbg !271 + %7336 = getelementptr inbounds nuw i8, ptr addrspace(3) %7334, i32 4, !dbg !271 + %7337 = load float, ptr addrspace(3) %7336, align 4, !dbg !271 + %7338 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5191, !dbg !271 + %7339 = load float, ptr addrspace(3) %7338, align 8, !dbg !271 + %7340 = getelementptr inbounds nuw i8, ptr addrspace(3) %7338, i32 4, !dbg !271 + %7341 = load float, ptr addrspace(3) %7340, align 4, !dbg !271 + %7342 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5197, !dbg !271 + %7343 = load float, ptr addrspace(3) %7342, align 8, !dbg !271 + %7344 = getelementptr inbounds nuw i8, ptr addrspace(3) %7342, i32 4, !dbg !271 + %7345 = load float, ptr addrspace(3) %7344, align 4, !dbg !271 + %7346 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5203, !dbg !271 + %7347 = load float, ptr addrspace(3) %7346, align 8, !dbg !271 + %7348 = getelementptr inbounds nuw i8, ptr addrspace(3) %7346, i32 4, !dbg !271 + %7349 = load float, ptr addrspace(3) %7348, align 4, !dbg !271 + %7350 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5209, !dbg !271 + %7351 = load float, ptr addrspace(3) %7350, align 8, !dbg !271 + %7352 = getelementptr inbounds nuw i8, ptr addrspace(3) %7350, i32 4, !dbg !271 + %7353 = load float, ptr addrspace(3) %7352, align 4, !dbg !271 + %7354 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5215, !dbg !271 + %7355 = load float, ptr addrspace(3) %7354, align 8, !dbg !271 + %7356 = getelementptr inbounds nuw i8, ptr addrspace(3) %7354, i32 4, !dbg !271 + %7357 = load float, ptr addrspace(3) %7356, align 4, !dbg !271 + %7358 = getelementptr inbounds nuw i8, ptr addrspace(3) %7329, i32 %5221, !dbg !271 + %7359 = load float, ptr addrspace(3) %7358, align 8, !dbg !271 + %7360 = getelementptr inbounds nuw i8, ptr addrspace(3) %7358, i32 4, !dbg !271 + %7361 = load float, ptr addrspace(3) %7360, align 4, !dbg !271 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !295 + %7362 = add i32 %6061, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7363 = lshr exact i32 %7362, 4, !dbg !295 + %7364 = and i32 %7363, 16383, !dbg !295 + %7365 = zext nneg i32 %7364 to i64, !dbg !295 + %7366 = or disjoint i64 %7365, 4611686293372403712, !dbg !295 + %7367 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %7366, i64 %7053) #3, !dbg !295 + %7368 = add i32 %6073, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7369 = lshr exact i32 %7368, 4, !dbg !295 + %7370 = and i32 %7369, 16383, !dbg !295 + %7371 = zext nneg i32 %7370 to i64, !dbg !295 + %7372 = or disjoint i64 %7371, 4611686293372403712, !dbg !295 + %7373 = add i32 %7049, 32, !dbg !295 + %7374 = lshr exact i32 %7373, 4, !dbg !295 + %7375 = and i32 %7374, 16383, !dbg !295 + %7376 = zext nneg i32 %7375 to i64, !dbg !295 + %7377 = or disjoint i64 %7376, 4611686293338849280, !dbg !295 + %7378 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 0, !dbg !295 + %7379 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 1, !dbg !295 + %7380 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 2, !dbg !295 + %7381 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 3, !dbg !295 + %7382 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 4, !dbg !295 + %7383 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 5, !dbg !295 + %7384 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 6, !dbg !295 + %7385 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 7, !dbg !295 + %7386 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 8, !dbg !295 + %7387 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 9, !dbg !295 + %7388 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 10, !dbg !295 + %7389 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 11, !dbg !295 + %7390 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 12, !dbg !295 + %7391 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 13, !dbg !295 + %7392 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 14, !dbg !295 + %7393 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 15, !dbg !295 + %7394 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 16, !dbg !295 + %7395 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 17, !dbg !295 + %7396 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 18, !dbg !295 + %7397 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 19, !dbg !295 + %7398 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 20, !dbg !295 + %7399 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 21, !dbg !295 + %7400 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 22, !dbg !295 + %7401 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 23, !dbg !295 + %7402 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 24, !dbg !295 + %7403 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 25, !dbg !295 + %7404 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 26, !dbg !295 + %7405 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 27, !dbg !295 + %7406 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 28, !dbg !295 + %7407 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 29, !dbg !295 + %7408 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 30, !dbg !295 + %7409 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7367, 31, !dbg !295 + %7410 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7378, float %7379, float %7380, float %7381, float %7382, float %7383, float %7384, float %7385, float %7386, float %7387, float %7388, float %7389, float %7390, float %7391, float %7392, float %7393, float %7394, float %7395, float %7396, float %7397, float %7398, float %7399, float %7400, float %7401, float %7402, float %7403, float %7404, float %7405, float %7406, float %7407, float %7408, float %7409, i64 %7372, i64 %7377, i1 true) #3, !dbg !295 + %7411 = add i32 %6117, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7412 = lshr exact i32 %7411, 4, !dbg !295 + %7413 = and i32 %7412, 16383, !dbg !295 + %7414 = zext nneg i32 %7413 to i64, !dbg !295 + %7415 = or disjoint i64 %7414, 4611686293372403712, !dbg !295 + %7416 = add i32 %7049, 64, !dbg !295 + %7417 = lshr exact i32 %7416, 4, !dbg !295 + %7418 = and i32 %7417, 16383, !dbg !295 + %7419 = zext nneg i32 %7418 to i64, !dbg !295 + %7420 = or disjoint i64 %7419, 4611686293338849280, !dbg !295 + %7421 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 0, !dbg !295 + %7422 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 1, !dbg !295 + %7423 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 2, !dbg !295 + %7424 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 3, !dbg !295 + %7425 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 4, !dbg !295 + %7426 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 5, !dbg !295 + %7427 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 6, !dbg !295 + %7428 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 7, !dbg !295 + %7429 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 8, !dbg !295 + %7430 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 9, !dbg !295 + %7431 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 10, !dbg !295 + %7432 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 11, !dbg !295 + %7433 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 12, !dbg !295 + %7434 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 13, !dbg !295 + %7435 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 14, !dbg !295 + %7436 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 15, !dbg !295 + %7437 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 16, !dbg !295 + %7438 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 17, !dbg !295 + %7439 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 18, !dbg !295 + %7440 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 19, !dbg !295 + %7441 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 20, !dbg !295 + %7442 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 21, !dbg !295 + %7443 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 22, !dbg !295 + %7444 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 23, !dbg !295 + %7445 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 24, !dbg !295 + %7446 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 25, !dbg !295 + %7447 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 26, !dbg !295 + %7448 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 27, !dbg !295 + %7449 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 28, !dbg !295 + %7450 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 29, !dbg !295 + %7451 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 30, !dbg !295 + %7452 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7410, 31, !dbg !295 + %7453 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7421, float %7422, float %7423, float %7424, float %7425, float %7426, float %7427, float %7428, float %7429, float %7430, float %7431, float %7432, float %7433, float %7434, float %7435, float %7436, float %7437, float %7438, float %7439, float %7440, float %7441, float %7442, float %7443, float %7444, float %7445, float %7446, float %7447, float %7448, float %7449, float %7450, float %7451, float %7452, i64 %7415, i64 %7420, i1 true) #3, !dbg !295 + %7454 = add i32 %6161, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7455 = lshr exact i32 %7454, 4, !dbg !295 + %7456 = and i32 %7455, 16383, !dbg !295 + %7457 = zext nneg i32 %7456 to i64, !dbg !295 + %7458 = or disjoint i64 %7457, 4611686293372403712, !dbg !295 + %7459 = add i32 %7049, 96, !dbg !295 + %7460 = lshr exact i32 %7459, 4, !dbg !295 + %7461 = and i32 %7460, 16383, !dbg !295 + %7462 = zext nneg i32 %7461 to i64, !dbg !295 + %7463 = or disjoint i64 %7462, 4611686293338849280, !dbg !295 + %7464 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 0, !dbg !295 + %7465 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 1, !dbg !295 + %7466 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 2, !dbg !295 + %7467 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 3, !dbg !295 + %7468 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 4, !dbg !295 + %7469 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 5, !dbg !295 + %7470 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 6, !dbg !295 + %7471 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 7, !dbg !295 + %7472 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 8, !dbg !295 + %7473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 9, !dbg !295 + %7474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 10, !dbg !295 + %7475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 11, !dbg !295 + %7476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 12, !dbg !295 + %7477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 13, !dbg !295 + %7478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 14, !dbg !295 + %7479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 15, !dbg !295 + %7480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 16, !dbg !295 + %7481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 17, !dbg !295 + %7482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 18, !dbg !295 + %7483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 19, !dbg !295 + %7484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 20, !dbg !295 + %7485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 21, !dbg !295 + %7486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 22, !dbg !295 + %7487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 23, !dbg !295 + %7488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 24, !dbg !295 + %7489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 25, !dbg !295 + %7490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 26, !dbg !295 + %7491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 27, !dbg !295 + %7492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 28, !dbg !295 + %7493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 29, !dbg !295 + %7494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 30, !dbg !295 + %7495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7453, 31, !dbg !295 + %7496 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7464, float %7465, float %7466, float %7467, float %7468, float %7469, float %7470, float %7471, float %7472, float %7473, float %7474, float %7475, float %7476, float %7477, float %7478, float %7479, float %7480, float %7481, float %7482, float %7483, float %7484, float %7485, float %7486, float %7487, float %7488, float %7489, float %7490, float %7491, float %7492, float %7493, float %7494, float %7495, i64 %7458, i64 %7463, i1 true) #3, !dbg !295 + %7497 = add i32 %6205, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7498 = lshr exact i32 %7497, 4, !dbg !295 + %7499 = and i32 %7498, 16383, !dbg !295 + %7500 = zext nneg i32 %7499 to i64, !dbg !295 + %7501 = or disjoint i64 %7500, 4611686293372403712, !dbg !295 + %7502 = add i32 %7049, 8192, !dbg !295 + %7503 = lshr exact i32 %7502, 4, !dbg !295 + %7504 = and i32 %7503, 16383, !dbg !295 + %7505 = zext nneg i32 %7504 to i64, !dbg !295 + %7506 = or disjoint i64 %7505, 4611686293338849280, !dbg !295 + %7507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 0, !dbg !295 + %7508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 1, !dbg !295 + %7509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 2, !dbg !295 + %7510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 3, !dbg !295 + %7511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 4, !dbg !295 + %7512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 5, !dbg !295 + %7513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 6, !dbg !295 + %7514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 7, !dbg !295 + %7515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 8, !dbg !295 + %7516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 9, !dbg !295 + %7517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 10, !dbg !295 + %7518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 11, !dbg !295 + %7519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 12, !dbg !295 + %7520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 13, !dbg !295 + %7521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 14, !dbg !295 + %7522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 15, !dbg !295 + %7523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 16, !dbg !295 + %7524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 17, !dbg !295 + %7525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 18, !dbg !295 + %7526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 19, !dbg !295 + %7527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 20, !dbg !295 + %7528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 21, !dbg !295 + %7529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 22, !dbg !295 + %7530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 23, !dbg !295 + %7531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 24, !dbg !295 + %7532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 25, !dbg !295 + %7533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 26, !dbg !295 + %7534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 27, !dbg !295 + %7535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 28, !dbg !295 + %7536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 29, !dbg !295 + %7537 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 30, !dbg !295 + %7538 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7496, 31, !dbg !295 + %7539 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7507, float %7508, float %7509, float %7510, float %7511, float %7512, float %7513, float %7514, float %7515, float %7516, float %7517, float %7518, float %7519, float %7520, float %7521, float %7522, float %7523, float %7524, float %7525, float %7526, float %7527, float %7528, float %7529, float %7530, float %7531, float %7532, float %7533, float %7534, float %7535, float %7536, float %7537, float %7538, i64 %7501, i64 %7506, i1 true) #3, !dbg !295 + %7540 = add i32 %6249, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7541 = lshr exact i32 %7540, 4, !dbg !295 + %7542 = and i32 %7541, 16383, !dbg !295 + %7543 = zext nneg i32 %7542 to i64, !dbg !295 + %7544 = or disjoint i64 %7543, 4611686293372403712, !dbg !295 + %7545 = add i32 %7049, 8224, !dbg !295 + %7546 = lshr exact i32 %7545, 4, !dbg !295 + %7547 = and i32 %7546, 16383, !dbg !295 + %7548 = zext nneg i32 %7547 to i64, !dbg !295 + %7549 = or disjoint i64 %7548, 4611686293338849280, !dbg !295 + %7550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 0, !dbg !295 + %7551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 1, !dbg !295 + %7552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 2, !dbg !295 + %7553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 3, !dbg !295 + %7554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 4, !dbg !295 + %7555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 5, !dbg !295 + %7556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 6, !dbg !295 + %7557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 7, !dbg !295 + %7558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 8, !dbg !295 + %7559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 9, !dbg !295 + %7560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 10, !dbg !295 + %7561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 11, !dbg !295 + %7562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 12, !dbg !295 + %7563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 13, !dbg !295 + %7564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 14, !dbg !295 + %7565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 15, !dbg !295 + %7566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 16, !dbg !295 + %7567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 17, !dbg !295 + %7568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 18, !dbg !295 + %7569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 19, !dbg !295 + %7570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 20, !dbg !295 + %7571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 21, !dbg !295 + %7572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 22, !dbg !295 + %7573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 23, !dbg !295 + %7574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 24, !dbg !295 + %7575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 25, !dbg !295 + %7576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 26, !dbg !295 + %7577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 27, !dbg !295 + %7578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 28, !dbg !295 + %7579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 29, !dbg !295 + %7580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 30, !dbg !295 + %7581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7539, 31, !dbg !295 + %7582 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7550, float %7551, float %7552, float %7553, float %7554, float %7555, float %7556, float %7557, float %7558, float %7559, float %7560, float %7561, float %7562, float %7563, float %7564, float %7565, float %7566, float %7567, float %7568, float %7569, float %7570, float %7571, float %7572, float %7573, float %7574, float %7575, float %7576, float %7577, float %7578, float %7579, float %7580, float %7581, i64 %7544, i64 %7549, i1 true) #3, !dbg !295 + %7583 = add i32 %6293, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7584 = lshr exact i32 %7583, 4, !dbg !295 + %7585 = and i32 %7584, 16383, !dbg !295 + %7586 = zext nneg i32 %7585 to i64, !dbg !295 + %7587 = or disjoint i64 %7586, 4611686293372403712, !dbg !295 + %7588 = add i32 %7049, 8256, !dbg !295 + %7589 = lshr exact i32 %7588, 4, !dbg !295 + %7590 = and i32 %7589, 16383, !dbg !295 + %7591 = zext nneg i32 %7590 to i64, !dbg !295 + %7592 = or disjoint i64 %7591, 4611686293338849280, !dbg !295 + %7593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 0, !dbg !295 + %7594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 1, !dbg !295 + %7595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 2, !dbg !295 + %7596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 3, !dbg !295 + %7597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 4, !dbg !295 + %7598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 5, !dbg !295 + %7599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 6, !dbg !295 + %7600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 7, !dbg !295 + %7601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 8, !dbg !295 + %7602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 9, !dbg !295 + %7603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 10, !dbg !295 + %7604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 11, !dbg !295 + %7605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 12, !dbg !295 + %7606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 13, !dbg !295 + %7607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 14, !dbg !295 + %7608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 15, !dbg !295 + %7609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 16, !dbg !295 + %7610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 17, !dbg !295 + %7611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 18, !dbg !295 + %7612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 19, !dbg !295 + %7613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 20, !dbg !295 + %7614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 21, !dbg !295 + %7615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 22, !dbg !295 + %7616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 23, !dbg !295 + %7617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 24, !dbg !295 + %7618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 25, !dbg !295 + %7619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 26, !dbg !295 + %7620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 27, !dbg !295 + %7621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 28, !dbg !295 + %7622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 29, !dbg !295 + %7623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 30, !dbg !295 + %7624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7582, 31, !dbg !295 + %7625 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7593, float %7594, float %7595, float %7596, float %7597, float %7598, float %7599, float %7600, float %7601, float %7602, float %7603, float %7604, float %7605, float %7606, float %7607, float %7608, float %7609, float %7610, float %7611, float %7612, float %7613, float %7614, float %7615, float %7616, float %7617, float %7618, float %7619, float %7620, float %7621, float %7622, float %7623, float %7624, i64 %7587, i64 %7592, i1 true) #3, !dbg !295 + %7626 = add i32 %6337, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !295 + %7627 = lshr exact i32 %7626, 4, !dbg !295 + %7628 = and i32 %7627, 16383, !dbg !295 + %7629 = zext nneg i32 %7628 to i64, !dbg !295 + %7630 = or disjoint i64 %7629, 4611686293372403712, !dbg !295 + %7631 = add i32 %7049, 8288, !dbg !295 + %7632 = lshr exact i32 %7631, 4, !dbg !295 + %7633 = and i32 %7632, 16383, !dbg !295 + %7634 = zext nneg i32 %7633 to i64, !dbg !295 + %7635 = or disjoint i64 %7634, 4611686293338849280, !dbg !295 + %7636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 0, !dbg !295 + %7637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 1, !dbg !295 + %7638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 2, !dbg !295 + %7639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 3, !dbg !295 + %7640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 4, !dbg !295 + %7641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 5, !dbg !295 + %7642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 6, !dbg !295 + %7643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 7, !dbg !295 + %7644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 8, !dbg !295 + %7645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 9, !dbg !295 + %7646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 10, !dbg !295 + %7647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 11, !dbg !295 + %7648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 12, !dbg !295 + %7649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 13, !dbg !295 + %7650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 14, !dbg !295 + %7651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 15, !dbg !295 + %7652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 16, !dbg !295 + %7653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 17, !dbg !295 + %7654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 18, !dbg !295 + %7655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 19, !dbg !295 + %7656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 20, !dbg !295 + %7657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 21, !dbg !295 + %7658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 22, !dbg !295 + %7659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 23, !dbg !295 + %7660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 24, !dbg !295 + %7661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 25, !dbg !295 + %7662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 26, !dbg !295 + %7663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 27, !dbg !295 + %7664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 28, !dbg !295 + %7665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 29, !dbg !295 + %7666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 30, !dbg !295 + %7667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7625, 31, !dbg !295 + %7668 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %7636, float %7637, float %7638, float %7639, float %7640, float %7641, float %7642, float %7643, float %7644, float %7645, float %7646, float %7647, float %7648, float %7649, float %7650, float %7651, float %7652, float %7653, float %7654, float %7655, float %7656, float %7657, float %7658, float %7659, float %7660, float %7661, float %7662, float %7663, float %7664, float %7665, float %7666, float %7667, i64 %7630, i64 %7635, i1 true) #3, !dbg !295 + %7669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 0, !dbg !295 + %7670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 1, !dbg !295 + %7671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 2, !dbg !295 + %7672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 3, !dbg !295 + %7673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 4, !dbg !295 + %7674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 5, !dbg !295 + %7675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 6, !dbg !295 + %7676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 7, !dbg !295 + %7677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 8, !dbg !295 + %7678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 9, !dbg !295 + %7679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 10, !dbg !295 + %7680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 11, !dbg !295 + %7681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 12, !dbg !295 + %7682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 13, !dbg !295 + %7683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 14, !dbg !295 + %7684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 15, !dbg !295 + %7685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 16, !dbg !295 + %7686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 17, !dbg !295 + %7687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 18, !dbg !295 + %7688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 19, !dbg !295 + %7689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 20, !dbg !295 + %7690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 21, !dbg !295 + %7691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 22, !dbg !295 + %7692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 23, !dbg !295 + %7693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 24, !dbg !295 + %7694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 25, !dbg !295 + %7695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 26, !dbg !295 + %7696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 27, !dbg !295 + %7697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 28, !dbg !295 + %7698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 29, !dbg !295 + %7699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 30, !dbg !295 + %7700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7668, 31, !dbg !295 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !295 + %7701 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %7669, float %7670, float %7671, float %7672, float %7673, float %7674, float %7675, float %7676, float %7677, float %7678, float %7679, float %7680, float %7681, float %7682, float %7683, float %7684, float %7685, float %7686, float %7687, float %7688, float %7689, float %7690, float %7691, float %7692, float %7693, float %7694, float %7695, float %7696, float %7697, float %7698, float %7699, float %7700, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %6984, i32 0, i32 0) #3, !dbg !295 + %7702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 0, !dbg !295 + %7703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 1, !dbg !295 + %7704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 2, !dbg !295 + %7705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 3, !dbg !295 + %7706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 4, !dbg !295 + %7707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 5, !dbg !295 + %7708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 6, !dbg !295 + %7709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 7, !dbg !295 + %7710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 8, !dbg !295 + %7711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 9, !dbg !295 + %7712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 10, !dbg !295 + %7713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 11, !dbg !295 + %7714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 12, !dbg !295 + %7715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 13, !dbg !295 + %7716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 14, !dbg !295 + %7717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 15, !dbg !295 + %7718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 16, !dbg !295 + %7719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 17, !dbg !295 + %7720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 18, !dbg !295 + %7721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 19, !dbg !295 + %7722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 20, !dbg !295 + %7723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 21, !dbg !295 + %7724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 22, !dbg !295 + %7725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 23, !dbg !295 + %7726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 24, !dbg !295 + %7727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 25, !dbg !295 + %7728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 26, !dbg !295 + %7729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 27, !dbg !295 + %7730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 28, !dbg !295 + %7731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 29, !dbg !295 + %7732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 30, !dbg !295 + %7733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %7701, 31, !dbg !295 + %7734 = fsub float %7702, %7331, !dbg !296 + %7735 = fsub float %7703, %7333, !dbg !296 + %7736 = fsub float %7704, %7331, !dbg !296 + %7737 = fsub float %7705, %7333, !dbg !296 + %7738 = fsub float %7706, %7335, !dbg !296 + %7739 = fsub float %7707, %7337, !dbg !296 + %7740 = fsub float %7708, %7335, !dbg !296 + %7741 = fsub float %7709, %7337, !dbg !296 + %7742 = fsub float %7710, %7339, !dbg !296 + %7743 = fsub float %7711, %7341, !dbg !296 + %7744 = fsub float %7712, %7339, !dbg !296 + %7745 = fsub float %7713, %7341, !dbg !296 + %7746 = fsub float %7714, %7343, !dbg !296 + %7747 = fsub float %7715, %7345, !dbg !296 + %7748 = fsub float %7716, %7343, !dbg !296 + %7749 = fsub float %7717, %7345, !dbg !296 + %7750 = fsub float %7718, %7347, !dbg !296 + %7751 = fsub float %7719, %7349, !dbg !296 + %7752 = fsub float %7720, %7347, !dbg !296 + %7753 = fsub float %7721, %7349, !dbg !296 + %7754 = fsub float %7722, %7351, !dbg !296 + %7755 = fsub float %7723, %7353, !dbg !296 + %7756 = fsub float %7724, %7351, !dbg !296 + %7757 = fsub float %7725, %7353, !dbg !296 + %7758 = fsub float %7726, %7355, !dbg !296 + %7759 = fsub float %7727, %7357, !dbg !296 + %7760 = fsub float %7728, %7355, !dbg !296 + %7761 = fsub float %7729, %7357, !dbg !296 + %7762 = fsub float %7730, %7359, !dbg !296 + %7763 = fsub float %7731, %7361, !dbg !296 + %7764 = fsub float %7732, %7359, !dbg !296 + %7765 = fsub float %7733, %7361, !dbg !296 + %7766 = fmul float %.0.i1323, %7734, !dbg !297 + %7767 = fmul float %.0.i1326, %7735, !dbg !297 + %7768 = fmul float %.0.i1329, %7736, !dbg !297 + %7769 = fmul float %.0.i1332, %7737, !dbg !297 + %7770 = fmul float %.0.i1335, %7738, !dbg !297 + %7771 = fmul float %.0.i1338, %7739, !dbg !297 + %7772 = fmul float %.0.i1341, %7740, !dbg !297 + %7773 = fmul float %.0.i1344, %7741, !dbg !297 + %7774 = fmul float %.0.i1347, %7742, !dbg !297 + %7775 = fmul float %.0.i1350, %7743, !dbg !297 + %7776 = fmul float %.0.i1353, %7744, !dbg !297 + %7777 = fmul float %.0.i1356, %7745, !dbg !297 + %7778 = fmul float %.0.i1359, %7746, !dbg !297 + %7779 = fmul float %.0.i1362, %7747, !dbg !297 + %7780 = fmul float %.0.i1365, %7748, !dbg !297 + %7781 = fmul float %.0.i1368, %7749, !dbg !297 + %7782 = fmul float %.0.i1371, %7750, !dbg !297 + %7783 = fmul float %.0.i1374, %7751, !dbg !297 + %7784 = fmul float %.0.i1377, %7752, !dbg !297 + %7785 = fmul float %.0.i1380, %7753, !dbg !297 + %7786 = fmul float %.0.i1383, %7754, !dbg !297 + %7787 = fmul float %.0.i1386, %7755, !dbg !297 + %7788 = fmul float %.0.i1389, %7756, !dbg !297 + %7789 = fmul float %.0.i1392, %7757, !dbg !297 + %7790 = fmul float %.0.i1395, %7758, !dbg !297 + %7791 = fmul float %.0.i1398, %7759, !dbg !297 + %7792 = fmul float %.0.i1401, %7760, !dbg !297 + %7793 = fmul float %.0.i1404, %7761, !dbg !297 + %7794 = fmul float %.0.i1407, %7762, !dbg !297 + %7795 = fmul float %.0.i1410, %7763, !dbg !297 + %7796 = fmul float %.0.i1413, %7764, !dbg !297 + %7797 = fmul float %.0.i1416, %7765, !dbg !297 + %7798 = fptrunc float %7766 to bfloat, !dbg !298 + %7799 = select i1 %6696, bfloat %7798, bfloat 0xR0000, !dbg !299 + %7800 = fptrunc float %7767 to bfloat, !dbg !298 + %7801 = select i1 %6697, bfloat %7800, bfloat 0xR0000, !dbg !299 + %7802 = fptrunc float %7768 to bfloat, !dbg !298 + %7803 = select i1 %6698, bfloat %7802, bfloat 0xR0000, !dbg !299 + %7804 = fptrunc float %7769 to bfloat, !dbg !298 + %7805 = select i1 %6699, bfloat %7804, bfloat 0xR0000, !dbg !299 + %7806 = fptrunc float %7770 to bfloat, !dbg !298 + %7807 = select i1 %6700, bfloat %7806, bfloat 0xR0000, !dbg !299 + %7808 = fptrunc float %7771 to bfloat, !dbg !298 + %7809 = select i1 %6701, bfloat %7808, bfloat 0xR0000, !dbg !299 + %7810 = fptrunc float %7772 to bfloat, !dbg !298 + %7811 = select i1 %6702, bfloat %7810, bfloat 0xR0000, !dbg !299 + %7812 = fptrunc float %7773 to bfloat, !dbg !298 + %7813 = select i1 %6703, bfloat %7812, bfloat 0xR0000, !dbg !299 + %7814 = fptrunc float %7774 to bfloat, !dbg !298 + %7815 = select i1 %6704, bfloat %7814, bfloat 0xR0000, !dbg !299 + %7816 = fptrunc float %7775 to bfloat, !dbg !298 + %7817 = select i1 %6705, bfloat %7816, bfloat 0xR0000, !dbg !299 + %7818 = fptrunc float %7776 to bfloat, !dbg !298 + %7819 = select i1 %6706, bfloat %7818, bfloat 0xR0000, !dbg !299 + %7820 = fptrunc float %7777 to bfloat, !dbg !298 + %7821 = select i1 %6707, bfloat %7820, bfloat 0xR0000, !dbg !299 + %7822 = fptrunc float %7778 to bfloat, !dbg !298 + %7823 = select i1 %6708, bfloat %7822, bfloat 0xR0000, !dbg !299 + %7824 = fptrunc float %7779 to bfloat, !dbg !298 + %7825 = select i1 %6709, bfloat %7824, bfloat 0xR0000, !dbg !299 + %7826 = fptrunc float %7780 to bfloat, !dbg !298 + %7827 = select i1 %6710, bfloat %7826, bfloat 0xR0000, !dbg !299 + %7828 = fptrunc float %7781 to bfloat, !dbg !298 + %7829 = select i1 %6711, bfloat %7828, bfloat 0xR0000, !dbg !299 + %7830 = fptrunc float %7782 to bfloat, !dbg !298 + %7831 = select i1 %6712, bfloat %7830, bfloat 0xR0000, !dbg !299 + %7832 = fptrunc float %7783 to bfloat, !dbg !298 + %7833 = select i1 %6713, bfloat %7832, bfloat 0xR0000, !dbg !299 + %7834 = fptrunc float %7784 to bfloat, !dbg !298 + %7835 = select i1 %6714, bfloat %7834, bfloat 0xR0000, !dbg !299 + %7836 = fptrunc float %7785 to bfloat, !dbg !298 + %7837 = select i1 %6715, bfloat %7836, bfloat 0xR0000, !dbg !299 + %7838 = fptrunc float %7786 to bfloat, !dbg !298 + %7839 = select i1 %6716, bfloat %7838, bfloat 0xR0000, !dbg !299 + %7840 = fptrunc float %7787 to bfloat, !dbg !298 + %7841 = select i1 %6717, bfloat %7840, bfloat 0xR0000, !dbg !299 + %7842 = fptrunc float %7788 to bfloat, !dbg !298 + %7843 = select i1 %6718, bfloat %7842, bfloat 0xR0000, !dbg !299 + %7844 = fptrunc float %7789 to bfloat, !dbg !298 + %7845 = select i1 %6719, bfloat %7844, bfloat 0xR0000, !dbg !299 + %7846 = fptrunc float %7790 to bfloat, !dbg !298 + %7847 = select i1 %6720, bfloat %7846, bfloat 0xR0000, !dbg !299 + %7848 = fptrunc float %7791 to bfloat, !dbg !298 + %7849 = select i1 %6721, bfloat %7848, bfloat 0xR0000, !dbg !299 + %7850 = fptrunc float %7792 to bfloat, !dbg !298 + %7851 = select i1 %6722, bfloat %7850, bfloat 0xR0000, !dbg !299 + %7852 = fptrunc float %7793 to bfloat, !dbg !298 + %7853 = select i1 %6723, bfloat %7852, bfloat 0xR0000, !dbg !299 + %7854 = fptrunc float %7794 to bfloat, !dbg !298 + %7855 = select i1 %6724, bfloat %7854, bfloat 0xR0000, !dbg !299 + %7856 = fptrunc float %7795 to bfloat, !dbg !298 + %7857 = select i1 %6725, bfloat %7856, bfloat 0xR0000, !dbg !299 + %7858 = fptrunc float %7796 to bfloat, !dbg !298 + %7859 = select i1 %6726, bfloat %7858, bfloat 0xR0000, !dbg !299 + %7860 = fptrunc float %7797 to bfloat, !dbg !298 + %7861 = select i1 %6727, bfloat %7860, bfloat 0xR0000, !dbg !299 + %7862 = insertelement <2 x bfloat> poison, bfloat %7799, i64 0, !dbg !300 + %7863 = insertelement <2 x bfloat> %7862, bfloat %7801, i64 1, !dbg !300 + %7864 = bitcast <2 x bfloat> %7863 to i32, !dbg !300 + %7865 = insertelement <2 x bfloat> poison, bfloat %7803, i64 0, !dbg !300 + %7866 = insertelement <2 x bfloat> %7865, bfloat %7805, i64 1, !dbg !300 + %7867 = bitcast <2 x bfloat> %7866 to i32, !dbg !300 + %7868 = insertelement <2 x bfloat> poison, bfloat %7807, i64 0, !dbg !300 + %7869 = insertelement <2 x bfloat> %7868, bfloat %7809, i64 1, !dbg !300 + %7870 = bitcast <2 x bfloat> %7869 to i32, !dbg !300 + %7871 = insertelement <2 x bfloat> poison, bfloat %7811, i64 0, !dbg !300 + %7872 = insertelement <2 x bfloat> %7871, bfloat %7813, i64 1, !dbg !300 + %7873 = bitcast <2 x bfloat> %7872 to i32, !dbg !300 + %7874 = insertelement <2 x bfloat> poison, bfloat %7815, i64 0, !dbg !300 + %7875 = insertelement <2 x bfloat> %7874, bfloat %7817, i64 1, !dbg !300 + %7876 = bitcast <2 x bfloat> %7875 to i32, !dbg !300 + %7877 = insertelement <2 x bfloat> poison, bfloat %7819, i64 0, !dbg !300 + %7878 = insertelement <2 x bfloat> %7877, bfloat %7821, i64 1, !dbg !300 + %7879 = bitcast <2 x bfloat> %7878 to i32, !dbg !300 + %7880 = insertelement <2 x bfloat> poison, bfloat %7823, i64 0, !dbg !300 + %7881 = insertelement <2 x bfloat> %7880, bfloat %7825, i64 1, !dbg !300 + %7882 = bitcast <2 x bfloat> %7881 to i32, !dbg !300 + %7883 = insertelement <2 x bfloat> poison, bfloat %7827, i64 0, !dbg !300 + %7884 = insertelement <2 x bfloat> %7883, bfloat %7829, i64 1, !dbg !300 + %7885 = bitcast <2 x bfloat> %7884 to i32, !dbg !300 + %7886 = insertelement <2 x bfloat> poison, bfloat %7831, i64 0, !dbg !300 + %7887 = insertelement <2 x bfloat> %7886, bfloat %7833, i64 1, !dbg !300 + %7888 = bitcast <2 x bfloat> %7887 to i32, !dbg !300 + %7889 = insertelement <2 x bfloat> poison, bfloat %7835, i64 0, !dbg !300 + %7890 = insertelement <2 x bfloat> %7889, bfloat %7837, i64 1, !dbg !300 + %7891 = bitcast <2 x bfloat> %7890 to i32, !dbg !300 + %7892 = insertelement <2 x bfloat> poison, bfloat %7839, i64 0, !dbg !300 + %7893 = insertelement <2 x bfloat> %7892, bfloat %7841, i64 1, !dbg !300 + %7894 = bitcast <2 x bfloat> %7893 to i32, !dbg !300 + %7895 = insertelement <2 x bfloat> poison, bfloat %7843, i64 0, !dbg !300 + %7896 = insertelement <2 x bfloat> %7895, bfloat %7845, i64 1, !dbg !300 + %7897 = bitcast <2 x bfloat> %7896 to i32, !dbg !300 + %7898 = insertelement <2 x bfloat> poison, bfloat %7847, i64 0, !dbg !300 + %7899 = insertelement <2 x bfloat> %7898, bfloat %7849, i64 1, !dbg !300 + %7900 = bitcast <2 x bfloat> %7899 to i32, !dbg !300 + %7901 = insertelement <2 x bfloat> poison, bfloat %7851, i64 0, !dbg !300 + %7902 = insertelement <2 x bfloat> %7901, bfloat %7853, i64 1, !dbg !300 + %7903 = bitcast <2 x bfloat> %7902 to i32, !dbg !300 + %7904 = insertelement <2 x bfloat> poison, bfloat %7855, i64 0, !dbg !300 + %7905 = insertelement <2 x bfloat> %7904, bfloat %7857, i64 1, !dbg !300 + %7906 = bitcast <2 x bfloat> %7905 to i32, !dbg !300 + %7907 = insertelement <2 x bfloat> poison, bfloat %7859, i64 0, !dbg !300 + %7908 = insertelement <2 x bfloat> %7907, bfloat %7861, i64 1, !dbg !300 + %7909 = bitcast <2 x bfloat> %7908 to i32, !dbg !300 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !300 + %7910 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %5901, float %5902, float %5903, float %5904, float %5905, float %5906, float %5907, float %5908, float %5909, float %5910, float %5911, float %5912, float %5913, float %5914, float %5915, float %5916, float %5917, float %5918, float %5919, float %5920, float %5921, float %5922, float %5923, float %5924, float %5925, float %5926, float %5927, float %5928, float %5929, float %5930, float %5931, float %5932, float %5933, float %5934, float %5935, float %5936, float %5937, float %5938, float %5939, float %5940, float %5941, float %5942, float %5943, float %5944, float %5945, float %5946, float %5947, float %5948, float %5949, float %5950, float %5951, float %5952, float %5953, float %5954, float %5955, float %5956, float %5957, float %5958, float %5959, float %5960, float %5961, float %5962, float %5963, float %5964, i32 %7864, i32 %7867, i32 %7870, i32 %7873, i64 %6071, i1 true) #3, !dbg !300 + %7911 = add i32 %6067, 2048, !dbg !300 + %7912 = lshr exact i32 %7911, 4, !dbg !300 + %7913 = and i32 %7912, 16383, !dbg !300 + %7914 = zext nneg i32 %7913 to i64, !dbg !300 + %7915 = or disjoint i64 %7914, 4611686293338849280, !dbg !300 + %7916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 0, !dbg !300 + %7917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 1, !dbg !300 + %7918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 2, !dbg !300 + %7919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 3, !dbg !300 + %7920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 4, !dbg !300 + %7921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 5, !dbg !300 + %7922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 6, !dbg !300 + %7923 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 7, !dbg !300 + %7924 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 8, !dbg !300 + %7925 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 9, !dbg !300 + %7926 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 10, !dbg !300 + %7927 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 11, !dbg !300 + %7928 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 12, !dbg !300 + %7929 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 13, !dbg !300 + %7930 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 14, !dbg !300 + %7931 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 15, !dbg !300 + %7932 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 16, !dbg !300 + %7933 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 17, !dbg !300 + %7934 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 18, !dbg !300 + %7935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 19, !dbg !300 + %7936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 20, !dbg !300 + %7937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 21, !dbg !300 + %7938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 22, !dbg !300 + %7939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 23, !dbg !300 + %7940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 24, !dbg !300 + %7941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 25, !dbg !300 + %7942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 26, !dbg !300 + %7943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 27, !dbg !300 + %7944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 28, !dbg !300 + %7945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 29, !dbg !300 + %7946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 30, !dbg !300 + %7947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 31, !dbg !300 + %7948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 32, !dbg !300 + %7949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 33, !dbg !300 + %7950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 34, !dbg !300 + %7951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 35, !dbg !300 + %7952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 36, !dbg !300 + %7953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 37, !dbg !300 + %7954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 38, !dbg !300 + %7955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 39, !dbg !300 + %7956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 40, !dbg !300 + %7957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 41, !dbg !300 + %7958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 42, !dbg !300 + %7959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 43, !dbg !300 + %7960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 44, !dbg !300 + %7961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 45, !dbg !300 + %7962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 46, !dbg !300 + %7963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 47, !dbg !300 + %7964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 48, !dbg !300 + %7965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 49, !dbg !300 + %7966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 50, !dbg !300 + %7967 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 51, !dbg !300 + %7968 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 52, !dbg !300 + %7969 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 53, !dbg !300 + %7970 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 54, !dbg !300 + %7971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 55, !dbg !300 + %7972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 56, !dbg !300 + %7973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 57, !dbg !300 + %7974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 58, !dbg !300 + %7975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 59, !dbg !300 + %7976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 60, !dbg !300 + %7977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 61, !dbg !300 + %7978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 62, !dbg !300 + %7979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7910, 63, !dbg !300 + %7980 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7916, float %7917, float %7918, float %7919, float %7920, float %7921, float %7922, float %7923, float %7924, float %7925, float %7926, float %7927, float %7928, float %7929, float %7930, float %7931, float %7932, float %7933, float %7934, float %7935, float %7936, float %7937, float %7938, float %7939, float %7940, float %7941, float %7942, float %7943, float %7944, float %7945, float %7946, float %7947, float %7948, float %7949, float %7950, float %7951, float %7952, float %7953, float %7954, float %7955, float %7956, float %7957, float %7958, float %7959, float %7960, float %7961, float %7962, float %7963, float %7964, float %7965, float %7966, float %7967, float %7968, float %7969, float %7970, float %7971, float %7972, float %7973, float %7974, float %7975, float %7976, float %7977, float %7978, float %7979, i32 %7876, i32 %7879, i32 %7882, i32 %7885, i64 %7915, i1 true) #3, !dbg !300 + %7981 = add i32 %6067, 4096, !dbg !300 + %7982 = lshr exact i32 %7981, 4, !dbg !300 + %7983 = and i32 %7982, 16383, !dbg !300 + %7984 = zext nneg i32 %7983 to i64, !dbg !300 + %7985 = or disjoint i64 %7984, 4611686293338849280, !dbg !300 + %7986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 0, !dbg !300 + %7987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 1, !dbg !300 + %7988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 2, !dbg !300 + %7989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 3, !dbg !300 + %7990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 4, !dbg !300 + %7991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 5, !dbg !300 + %7992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 6, !dbg !300 + %7993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 7, !dbg !300 + %7994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 8, !dbg !300 + %7995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 9, !dbg !300 + %7996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 10, !dbg !300 + %7997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 11, !dbg !300 + %7998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 12, !dbg !300 + %7999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 13, !dbg !300 + %8000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 14, !dbg !300 + %8001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 15, !dbg !300 + %8002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 16, !dbg !300 + %8003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 17, !dbg !300 + %8004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 18, !dbg !300 + %8005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 19, !dbg !300 + %8006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 20, !dbg !300 + %8007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 21, !dbg !300 + %8008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 22, !dbg !300 + %8009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 23, !dbg !300 + %8010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 24, !dbg !300 + %8011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 25, !dbg !300 + %8012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 26, !dbg !300 + %8013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 27, !dbg !300 + %8014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 28, !dbg !300 + %8015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 29, !dbg !300 + %8016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 30, !dbg !300 + %8017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 31, !dbg !300 + %8018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 32, !dbg !300 + %8019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 33, !dbg !300 + %8020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 34, !dbg !300 + %8021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 35, !dbg !300 + %8022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 36, !dbg !300 + %8023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 37, !dbg !300 + %8024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 38, !dbg !300 + %8025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 39, !dbg !300 + %8026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 40, !dbg !300 + %8027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 41, !dbg !300 + %8028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 42, !dbg !300 + %8029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 43, !dbg !300 + %8030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 44, !dbg !300 + %8031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 45, !dbg !300 + %8032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 46, !dbg !300 + %8033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 47, !dbg !300 + %8034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 48, !dbg !300 + %8035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 49, !dbg !300 + %8036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 50, !dbg !300 + %8037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 51, !dbg !300 + %8038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 52, !dbg !300 + %8039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 53, !dbg !300 + %8040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 54, !dbg !300 + %8041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 55, !dbg !300 + %8042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 56, !dbg !300 + %8043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 57, !dbg !300 + %8044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 58, !dbg !300 + %8045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 59, !dbg !300 + %8046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 60, !dbg !300 + %8047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 61, !dbg !300 + %8048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 62, !dbg !300 + %8049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %7980, 63, !dbg !300 + %8050 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %7986, float %7987, float %7988, float %7989, float %7990, float %7991, float %7992, float %7993, float %7994, float %7995, float %7996, float %7997, float %7998, float %7999, float %8000, float %8001, float %8002, float %8003, float %8004, float %8005, float %8006, float %8007, float %8008, float %8009, float %8010, float %8011, float %8012, float %8013, float %8014, float %8015, float %8016, float %8017, float %8018, float %8019, float %8020, float %8021, float %8022, float %8023, float %8024, float %8025, float %8026, float %8027, float %8028, float %8029, float %8030, float %8031, float %8032, float %8033, float %8034, float %8035, float %8036, float %8037, float %8038, float %8039, float %8040, float %8041, float %8042, float %8043, float %8044, float %8045, float %8046, float %8047, float %8048, float %8049, i32 %7888, i32 %7891, i32 %7894, i32 %7897, i64 %7985, i1 true) #3, !dbg !300 + %8051 = add i32 %6067, 6144, !dbg !300 + %8052 = lshr exact i32 %8051, 4, !dbg !300 + %8053 = and i32 %8052, 16383, !dbg !300 + %8054 = zext nneg i32 %8053 to i64, !dbg !300 + %8055 = or disjoint i64 %8054, 4611686293338849280, !dbg !300 + %8056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 0, !dbg !300 + %8057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 1, !dbg !300 + %8058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 2, !dbg !300 + %8059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 3, !dbg !300 + %8060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 4, !dbg !300 + %8061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 5, !dbg !300 + %8062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 6, !dbg !300 + %8063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 7, !dbg !300 + %8064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 8, !dbg !300 + %8065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 9, !dbg !300 + %8066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 10, !dbg !300 + %8067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 11, !dbg !300 + %8068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 12, !dbg !300 + %8069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 13, !dbg !300 + %8070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 14, !dbg !300 + %8071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 15, !dbg !300 + %8072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 16, !dbg !300 + %8073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 17, !dbg !300 + %8074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 18, !dbg !300 + %8075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 19, !dbg !300 + %8076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 20, !dbg !300 + %8077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 21, !dbg !300 + %8078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 22, !dbg !300 + %8079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 23, !dbg !300 + %8080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 24, !dbg !300 + %8081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 25, !dbg !300 + %8082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 26, !dbg !300 + %8083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 27, !dbg !300 + %8084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 28, !dbg !300 + %8085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 29, !dbg !300 + %8086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 30, !dbg !300 + %8087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 31, !dbg !300 + %8088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 32, !dbg !300 + %8089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 33, !dbg !300 + %8090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 34, !dbg !300 + %8091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 35, !dbg !300 + %8092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 36, !dbg !300 + %8093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 37, !dbg !300 + %8094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 38, !dbg !300 + %8095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 39, !dbg !300 + %8096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 40, !dbg !300 + %8097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 41, !dbg !300 + %8098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 42, !dbg !300 + %8099 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 43, !dbg !300 + %8100 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 44, !dbg !300 + %8101 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 45, !dbg !300 + %8102 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 46, !dbg !300 + %8103 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 47, !dbg !300 + %8104 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 48, !dbg !300 + %8105 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 49, !dbg !300 + %8106 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 50, !dbg !300 + %8107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 51, !dbg !300 + %8108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 52, !dbg !300 + %8109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 53, !dbg !300 + %8110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 54, !dbg !300 + %8111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 55, !dbg !300 + %8112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 56, !dbg !300 + %8113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 57, !dbg !300 + %8114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 58, !dbg !300 + %8115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 59, !dbg !300 + %8116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 60, !dbg !300 + %8117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 61, !dbg !300 + %8118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 62, !dbg !300 + %8119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8050, 63, !dbg !300 + %8120 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %8056, float %8057, float %8058, float %8059, float %8060, float %8061, float %8062, float %8063, float %8064, float %8065, float %8066, float %8067, float %8068, float %8069, float %8070, float %8071, float %8072, float %8073, float %8074, float %8075, float %8076, float %8077, float %8078, float %8079, float %8080, float %8081, float %8082, float %8083, float %8084, float %8085, float %8086, float %8087, float %8088, float %8089, float %8090, float %8091, float %8092, float %8093, float %8094, float %8095, float %8096, float %8097, float %8098, float %8099, float %8100, float %8101, float %8102, float %8103, float %8104, float %8105, float %8106, float %8107, float %8108, float %8109, float %8110, float %8111, float %8112, float %8113, float %8114, float %8115, float %8116, float %8117, float %8118, float %8119, i32 %7900, i32 %7903, i32 %7906, i32 %7909, i64 %8055, i1 true) #3, !dbg !300 + %8121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 0, !dbg !300 + %8122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 1, !dbg !300 + %8123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 2, !dbg !300 + %8124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 3, !dbg !300 + %8125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 4, !dbg !300 + %8126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 5, !dbg !300 + %8127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 6, !dbg !300 + %8128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 7, !dbg !300 + %8129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 8, !dbg !300 + %8130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 9, !dbg !300 + %8131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 10, !dbg !300 + %8132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 11, !dbg !300 + %8133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 12, !dbg !300 + %8134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 13, !dbg !300 + %8135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 14, !dbg !300 + %8136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 15, !dbg !300 + %8137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 16, !dbg !300 + %8138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 17, !dbg !300 + %8139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 18, !dbg !300 + %8140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 19, !dbg !300 + %8141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 20, !dbg !300 + %8142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 21, !dbg !300 + %8143 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 22, !dbg !300 + %8144 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 23, !dbg !300 + %8145 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 24, !dbg !300 + %8146 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 25, !dbg !300 + %8147 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 26, !dbg !300 + %8148 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 27, !dbg !300 + %8149 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 28, !dbg !300 + %8150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 29, !dbg !300 + %8151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 30, !dbg !300 + %8152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 31, !dbg !300 + %8153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 32, !dbg !300 + %8154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 33, !dbg !300 + %8155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 34, !dbg !300 + %8156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 35, !dbg !300 + %8157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 36, !dbg !300 + %8158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 37, !dbg !300 + %8159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 38, !dbg !300 + %8160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 39, !dbg !300 + %8161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 40, !dbg !300 + %8162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 41, !dbg !300 + %8163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 42, !dbg !300 + %8164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 43, !dbg !300 + %8165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 44, !dbg !300 + %8166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 45, !dbg !300 + %8167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 46, !dbg !300 + %8168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 47, !dbg !300 + %8169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 48, !dbg !300 + %8170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 49, !dbg !300 + %8171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 50, !dbg !300 + %8172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 51, !dbg !300 + %8173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 52, !dbg !300 + %8174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 53, !dbg !300 + %8175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 54, !dbg !300 + %8176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 55, !dbg !300 + %8177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 56, !dbg !300 + %8178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 57, !dbg !300 + %8179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 58, !dbg !300 + %8180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 59, !dbg !300 + %8181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 60, !dbg !300 + %8182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 61, !dbg !300 + %8183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 62, !dbg !300 + %8184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8120, 63, !dbg !300 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !300 + %8185 = insertelement <16 x i32> poison, i32 %5808, i64 0, !dbg !301 + %8186 = shufflevector <16 x i32> %8185, <16 x i32> poison, <16 x i32> zeroinitializer, !dbg !301 + %8187 = add <16 x i32> %5966, %8186, !dbg !301 + %8188 = add nuw nsw i32 %5965, 1, !dbg !274 + %8189 = lshr i32 %8188, 1, !dbg !302 + %8190 = zext nneg i32 %8189 to i64, !dbg !303 + %8191 = getelementptr i32, ptr addrspace(1) %4982, i64 %8190, !dbg !303 + %8192 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !304 + %8193 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %8191, i64 %8192, i1 %5968) #3, !dbg !304 + %8194 = add nuw nsw i32 %8189, 1, !dbg !305 + %8195 = icmp slt i32 %8194, %4986, !dbg !306 + %8196 = getelementptr i8, ptr addrspace(1) %8191, i64 4, !dbg !307 + %8197 = and i1 %5968, %8195, !dbg !274 + %8198 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !308 + %8199 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %8196, i64 %8198, i1 %8197) #3, !dbg !308 + %8200 = and i32 %5965, 1, !dbg !309 + %8201 = sub i32 %8199, %8193, !dbg !310 + %8202 = shl i32 %8201, 7, !dbg !311 + %8203 = add i32 %8202, -64, !dbg !312 + %8204 = xor i32 %8200, 1, !dbg !313 + %8205 = mul nuw nsw i32 %8203, %8204, !dbg !313 + %8206 = shl nuw nsw i32 %8200, 6, !dbg !314 + %8207 = add i32 %8205, %8206, !dbg !315 + %8208 = shl i32 %8207, 12, !dbg !316 + %8209 = sext i32 %8208 to i64, !dbg !272 + %8210 = getelementptr bfloat, ptr addrspace(1) %.pn1231675, i64 %8209, !dbg !272 + %8211 = getelementptr bfloat, ptr addrspace(1) %.pn1071676, i64 %8209, !dbg !272 + %8212 = getelementptr bfloat, ptr addrspace(1) %.pn911677, i64 %8209, !dbg !272 + %8213 = getelementptr bfloat, ptr addrspace(1) %.pn751678, i64 %8209, !dbg !272 + %8214 = shl i32 %8207, 7, !dbg !317 + %8215 = sext i32 %8214 to i64, !dbg !273 + %8216 = getelementptr bfloat, ptr addrspace(1) %.pn1871679, i64 %8215, !dbg !273 + %8217 = getelementptr bfloat, ptr addrspace(1) %.pn1711680, i64 %8215, !dbg !273 + %8218 = getelementptr bfloat, ptr addrspace(1) %.pn1551681, i64 %8215, !dbg !273 + %8219 = getelementptr bfloat, ptr addrspace(1) %.pn1391682, i64 %8215, !dbg !273 + %8220 = add i32 %8207, %.pn2191683, !dbg !301 + %8221 = add i32 %8207, %.pn2171684, !dbg !301 + %8222 = add i32 %8207, %.pn2151685, !dbg !301 + %8223 = add i32 %8207, %.pn2131686, !dbg !301 + %8224 = add i32 %8207, %.pn2111687, !dbg !301 + %8225 = add i32 %8207, %.pn2091688, !dbg !301 + %8226 = add i32 %8207, %.pn2071689, !dbg !301 + %8227 = add i32 %8207, %.pn2051690, !dbg !301 + %8228 = add i32 %8207, %.pn2031691, !dbg !301 + %8229 = add i32 %8207, %.pn2011692, !dbg !301 + %8230 = add i32 %8207, %.pn1991693, !dbg !301 + %8231 = add i32 %8207, %.pn1971694, !dbg !301 + %8232 = add i32 %8207, %.pn1951695, !dbg !301 + %8233 = add i32 %8207, %.pn1931696, !dbg !301 + %8234 = add i32 %8207, %.pn1911697, !dbg !301 + %8235 = add i32 %8207, %.pn1891698, !dbg !301 + %8236 = add i32 %8207, %5833, !dbg !301 + %8237 = add i32 %8207, %5834, !dbg !301 + %8238 = add i32 %8207, %5835, !dbg !301 + %8239 = add i32 %8207, %5836, !dbg !301 + %8240 = add i32 %8207, %5829, !dbg !301 + %8241 = add i32 %8207, %5830, !dbg !301 + %8242 = add i32 %8207, %5831, !dbg !301 + %8243 = add i32 %8207, %5832, !dbg !301 + %8244 = add i32 %5826, 1, !dbg !274 + %8245 = icmp sgt i32 %8244, 1, !dbg !274 + %8246 = select i1 %8245, i32 0, i32 %8244, !dbg !274 + %8247 = add i32 %5828, 1, !dbg !274 + %8248 = icmp sgt i32 %8247, 2, !dbg !274 + %8249 = select i1 %8248, i32 0, i32 %8247, !dbg !274 + %8250 = icmp slt i32 %8236, %17, !dbg !275 + %8251 = icmp slt i32 %8237, %17, !dbg !275 + %8252 = icmp slt i32 %8238, %17, !dbg !275 + %8253 = icmp slt i32 %8239, %17, !dbg !275 + %8254 = shl i32 %8249, 13, !dbg !266 + %8255 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %8254, !dbg !266 + %8256 = and i1 %5967, %8250, !dbg !274 + %8257 = and i1 %5967, %8251, !dbg !274 + %8258 = and i1 %5967, %8252, !dbg !274 + %8259 = and i1 %5967, %8253, !dbg !274 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !266 + %8260 = getelementptr inbounds nuw i8, ptr addrspace(3) %8255, i32 %5101, !dbg !266 + %8261 = select i1 %8256, i32 16, i32 0, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %8260, ptr addrspace(1) %8210, i32 %8261) #3, !dbg !266 + %8262 = getelementptr inbounds nuw i8, ptr addrspace(3) %8255, i32 %5104, !dbg !266 + %8263 = select i1 %8257, i32 16, i32 0, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8262, ptr addrspace(1) %8211, i32 %8263) #3, !dbg !266 + %8264 = getelementptr inbounds nuw i8, ptr addrspace(3) %8255, i32 %5107, !dbg !266 + %8265 = select i1 %8258, i32 16, i32 0, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8264, ptr addrspace(1) %8212, i32 %8265) #3, !dbg !266 + %8266 = getelementptr inbounds nuw i8, ptr addrspace(3) %8255, i32 %5110, !dbg !266 + %8267 = select i1 %8259, i32 16, i32 0, !dbg !266 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8266, ptr addrspace(1) %8213, i32 %8267) #3, !dbg !266 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !266 + %8268 = icmp slt i32 %8220, %17, !dbg !318 + %8269 = icmp slt i32 %8221, %17, !dbg !318 + %8270 = icmp slt i32 %8222, %17, !dbg !318 + %8271 = icmp slt i32 %8223, %17, !dbg !318 + %8272 = icmp slt i32 %8224, %17, !dbg !318 + %8273 = icmp slt i32 %8225, %17, !dbg !318 + %8274 = icmp slt i32 %8226, %17, !dbg !318 + %8275 = icmp slt i32 %8227, %17, !dbg !318 + %8276 = icmp slt i32 %8228, %17, !dbg !318 + %8277 = icmp slt i32 %8229, %17, !dbg !318 + %8278 = icmp slt i32 %8230, %17, !dbg !318 + %8279 = icmp slt i32 %8231, %17, !dbg !318 + %8280 = icmp slt i32 %8232, %17, !dbg !318 + %8281 = icmp slt i32 %8233, %17, !dbg !318 + %8282 = icmp slt i32 %8234, %17, !dbg !318 + %8283 = icmp slt i32 %8235, %17, !dbg !318 + %8284 = sext i32 %8220 to i64, !dbg !267 + %8285 = getelementptr float, ptr addrspace(1) %5718, i64 %8284, !dbg !267 + %8286 = sext i32 %8221 to i64, !dbg !267 + %8287 = getelementptr float, ptr addrspace(1) %5718, i64 %8286, !dbg !267 + %8288 = sext i32 %8222 to i64, !dbg !267 + %8289 = getelementptr float, ptr addrspace(1) %5718, i64 %8288, !dbg !267 + %8290 = sext i32 %8223 to i64, !dbg !267 + %8291 = getelementptr float, ptr addrspace(1) %5718, i64 %8290, !dbg !267 + %8292 = sext i32 %8224 to i64, !dbg !267 + %8293 = getelementptr float, ptr addrspace(1) %5718, i64 %8292, !dbg !267 + %8294 = sext i32 %8225 to i64, !dbg !267 + %8295 = getelementptr float, ptr addrspace(1) %5718, i64 %8294, !dbg !267 + %8296 = sext i32 %8226 to i64, !dbg !267 + %8297 = getelementptr float, ptr addrspace(1) %5718, i64 %8296, !dbg !267 + %8298 = sext i32 %8227 to i64, !dbg !267 + %8299 = getelementptr float, ptr addrspace(1) %5718, i64 %8298, !dbg !267 + %8300 = sext i32 %8228 to i64, !dbg !267 + %8301 = getelementptr float, ptr addrspace(1) %5718, i64 %8300, !dbg !267 + %8302 = sext i32 %8229 to i64, !dbg !267 + %8303 = getelementptr float, ptr addrspace(1) %5718, i64 %8302, !dbg !267 + %8304 = sext i32 %8230 to i64, !dbg !267 + %8305 = getelementptr float, ptr addrspace(1) %5718, i64 %8304, !dbg !267 + %8306 = sext i32 %8231 to i64, !dbg !267 + %8307 = getelementptr float, ptr addrspace(1) %5718, i64 %8306, !dbg !267 + %8308 = sext i32 %8232 to i64, !dbg !267 + %8309 = getelementptr float, ptr addrspace(1) %5718, i64 %8308, !dbg !267 + %8310 = sext i32 %8233 to i64, !dbg !267 + %8311 = getelementptr float, ptr addrspace(1) %5718, i64 %8310, !dbg !267 + %8312 = sext i32 %8234 to i64, !dbg !267 + %8313 = getelementptr float, ptr addrspace(1) %5718, i64 %8312, !dbg !267 + %8314 = sext i32 %8235 to i64, !dbg !267 + %8315 = getelementptr float, ptr addrspace(1) %5718, i64 %8314, !dbg !267 + %8316 = shl i32 %8246, 6, !dbg !268 + %8317 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %8316, !dbg !268 + %8318 = and i1 %5967, %8268, !dbg !274 + %8319 = and i1 %5967, %8269, !dbg !274 + %8320 = and i1 %5967, %8270, !dbg !274 + %8321 = and i1 %5967, %8271, !dbg !274 + %8322 = and i1 %5967, %8272, !dbg !274 + %8323 = and i1 %5967, %8273, !dbg !274 + %8324 = and i1 %5967, %8274, !dbg !274 + %8325 = and i1 %5967, %8275, !dbg !274 + %8326 = and i1 %5967, %8276, !dbg !274 + %8327 = and i1 %5967, %8277, !dbg !274 + %8328 = and i1 %5967, %8278, !dbg !274 + %8329 = and i1 %5967, %8279, !dbg !274 + %8330 = and i1 %5967, %8280, !dbg !274 + %8331 = and i1 %5967, %8281, !dbg !274 + %8332 = and i1 %5967, %8282, !dbg !274 + %8333 = and i1 %5967, %8283, !dbg !274 + %8334 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5179, !dbg !268 + %8335 = select i1 %8318, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %8334, ptr addrspace(1) %8285, i32 %8335, i1 %5178) #3, !dbg !268 + %8336 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5182, !dbg !268 + %8337 = select i1 %8319, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8336, ptr addrspace(1) %8287, i32 %8337, i1 %5178) #3, !dbg !268 + %8338 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5185, !dbg !268 + %8339 = select i1 %8320, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8338, ptr addrspace(1) %8289, i32 %8339, i1 %5178) #3, !dbg !268 + %8340 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5188, !dbg !268 + %8341 = select i1 %8321, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8340, ptr addrspace(1) %8291, i32 %8341, i1 %5178) #3, !dbg !268 + %8342 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5191, !dbg !268 + %8343 = select i1 %8322, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8342, ptr addrspace(1) %8293, i32 %8343, i1 %5178) #3, !dbg !268 + %8344 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5194, !dbg !268 + %8345 = select i1 %8323, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8344, ptr addrspace(1) %8295, i32 %8345, i1 %5178) #3, !dbg !268 + %8346 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5197, !dbg !268 + %8347 = select i1 %8324, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8346, ptr addrspace(1) %8297, i32 %8347, i1 %5178) #3, !dbg !268 + %8348 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5200, !dbg !268 + %8349 = select i1 %8325, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8348, ptr addrspace(1) %8299, i32 %8349, i1 %5178) #3, !dbg !268 + %8350 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5203, !dbg !268 + %8351 = select i1 %8326, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8350, ptr addrspace(1) %8301, i32 %8351, i1 %5178) #3, !dbg !268 + %8352 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5206, !dbg !268 + %8353 = select i1 %8327, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8352, ptr addrspace(1) %8303, i32 %8353, i1 %5178) #3, !dbg !268 + %8354 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5209, !dbg !268 + %8355 = select i1 %8328, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8354, ptr addrspace(1) %8305, i32 %8355, i1 %5178) #3, !dbg !268 + %8356 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5212, !dbg !268 + %8357 = select i1 %8329, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8356, ptr addrspace(1) %8307, i32 %8357, i1 %5178) #3, !dbg !268 + %8358 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5215, !dbg !268 + %8359 = select i1 %8330, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8358, ptr addrspace(1) %8309, i32 %8359, i1 %5178) #3, !dbg !268 + %8360 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5218, !dbg !268 + %8361 = select i1 %8331, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8360, ptr addrspace(1) %8311, i32 %8361, i1 %5178) #3, !dbg !268 + %8362 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5221, !dbg !268 + %8363 = select i1 %8332, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8362, ptr addrspace(1) %8313, i32 %8363, i1 %5178) #3, !dbg !268 + %8364 = getelementptr inbounds nuw i8, ptr addrspace(3) %8317, i32 %5224, !dbg !268 + %8365 = select i1 %8333, i32 4, i32 0, !dbg !268 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8364, ptr addrspace(1) %8315, i32 %8365, i1 %5178) #3, !dbg !268 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !268 + %8366 = icmp slt i32 %8240, %17, !dbg !319 + %8367 = icmp slt i32 %8241, %17, !dbg !319 + %8368 = icmp slt i32 %8242, %17, !dbg !319 + %8369 = icmp slt i32 %8243, %17, !dbg !319 + %8370 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %8254, !dbg !269 + %8371 = and i1 %5967, %8366, !dbg !274 + %8372 = and i1 %5967, %8367, !dbg !274 + %8373 = and i1 %5967, %8368, !dbg !274 + %8374 = and i1 %5967, %8369, !dbg !274 + %8375 = getelementptr inbounds nuw i8, ptr addrspace(3) %8370, i32 %5101, !dbg !269 + %8376 = select i1 %8371, i32 16, i32 0, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %8375, ptr addrspace(1) %8216, i32 %8376) #3, !dbg !269 + %8377 = getelementptr inbounds nuw i8, ptr addrspace(3) %8370, i32 %5104, !dbg !269 + %8378 = select i1 %8372, i32 16, i32 0, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8377, ptr addrspace(1) %8217, i32 %8378) #3, !dbg !269 + %8379 = getelementptr inbounds nuw i8, ptr addrspace(3) %8370, i32 %5107, !dbg !269 + %8380 = select i1 %8373, i32 16, i32 0, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8379, ptr addrspace(1) %8218, i32 %8380) #3, !dbg !269 + %8381 = getelementptr inbounds nuw i8, ptr addrspace(3) %8370, i32 %5110, !dbg !269 + %8382 = select i1 %8374, i32 16, i32 0, !dbg !269 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %8381, ptr addrspace(1) %8219, i32 %8382) #3, !dbg !269 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !269 + %8383 = getelementptr float, ptr addrspace(1) %5719, i64 %8284, !dbg !270 + %8384 = getelementptr float, ptr addrspace(1) %5719, i64 %8286, !dbg !270 + %8385 = getelementptr float, ptr addrspace(1) %5719, i64 %8288, !dbg !270 + %8386 = getelementptr float, ptr addrspace(1) %5719, i64 %8290, !dbg !270 + %8387 = getelementptr float, ptr addrspace(1) %5719, i64 %8292, !dbg !270 + %8388 = getelementptr float, ptr addrspace(1) %5719, i64 %8294, !dbg !270 + %8389 = getelementptr float, ptr addrspace(1) %5719, i64 %8296, !dbg !270 + %8390 = getelementptr float, ptr addrspace(1) %5719, i64 %8298, !dbg !270 + %8391 = getelementptr float, ptr addrspace(1) %5719, i64 %8300, !dbg !270 + %8392 = getelementptr float, ptr addrspace(1) %5719, i64 %8302, !dbg !270 + %8393 = getelementptr float, ptr addrspace(1) %5719, i64 %8304, !dbg !270 + %8394 = getelementptr float, ptr addrspace(1) %5719, i64 %8306, !dbg !270 + %8395 = getelementptr float, ptr addrspace(1) %5719, i64 %8308, !dbg !270 + %8396 = getelementptr float, ptr addrspace(1) %5719, i64 %8310, !dbg !270 + %8397 = getelementptr float, ptr addrspace(1) %5719, i64 %8312, !dbg !270 + %8398 = getelementptr float, ptr addrspace(1) %5719, i64 %8314, !dbg !270 + %8399 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %8316, !dbg !271 + %8400 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5179, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %8400, ptr addrspace(1) %8383, i32 %8335, i1 %5178) #3, !dbg !271 + %8401 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5182, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8401, ptr addrspace(1) %8384, i32 %8337, i1 %5178) #3, !dbg !271 + %8402 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5185, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8402, ptr addrspace(1) %8385, i32 %8339, i1 %5178) #3, !dbg !271 + %8403 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5188, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8403, ptr addrspace(1) %8386, i32 %8341, i1 %5178) #3, !dbg !271 + %8404 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5191, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8404, ptr addrspace(1) %8387, i32 %8343, i1 %5178) #3, !dbg !271 + %8405 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5194, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8405, ptr addrspace(1) %8388, i32 %8345, i1 %5178) #3, !dbg !271 + %8406 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5197, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8406, ptr addrspace(1) %8389, i32 %8347, i1 %5178) #3, !dbg !271 + %8407 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5200, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8407, ptr addrspace(1) %8390, i32 %8349, i1 %5178) #3, !dbg !271 + %8408 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5203, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8408, ptr addrspace(1) %8391, i32 %8351, i1 %5178) #3, !dbg !271 + %8409 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5206, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8409, ptr addrspace(1) %8392, i32 %8353, i1 %5178) #3, !dbg !271 + %8410 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5209, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8410, ptr addrspace(1) %8393, i32 %8355, i1 %5178) #3, !dbg !271 + %8411 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5212, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8411, ptr addrspace(1) %8394, i32 %8357, i1 %5178) #3, !dbg !271 + %8412 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5215, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8412, ptr addrspace(1) %8395, i32 %8359, i1 %5178) #3, !dbg !271 + %8413 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5218, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8413, ptr addrspace(1) %8396, i32 %8361, i1 %5178) #3, !dbg !271 + %8414 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5221, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8414, ptr addrspace(1) %8397, i32 %8363, i1 %5178) #3, !dbg !271 + %8415 = getelementptr inbounds nuw i8, ptr addrspace(3) %8399, i32 %5224, !dbg !271 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %8415, ptr addrspace(1) %8398, i32 %8365, i1 %5178) #3, !dbg !271 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !271 + %exitcond2266.not = icmp eq i32 %8188, %smax2265, !dbg !274 + br i1 %exitcond2266.not, label %._crit_edge1701, label %.lr.ph1700, !dbg !274 + +._crit_edge1701: ; preds = %__nv_exp2f.exit1417, %5575 + %8416 = phi float [ %5576, %5575 ], [ %8121, %__nv_exp2f.exit1417 ] + %8417 = phi float [ %5577, %5575 ], [ %8122, %__nv_exp2f.exit1417 ] + %8418 = phi float [ %5578, %5575 ], [ %8123, %__nv_exp2f.exit1417 ] + %8419 = phi float [ %5579, %5575 ], [ %8124, %__nv_exp2f.exit1417 ] + %8420 = phi float [ %5580, %5575 ], [ %8125, %__nv_exp2f.exit1417 ] + %8421 = phi float [ %5581, %5575 ], [ %8126, %__nv_exp2f.exit1417 ] + %8422 = phi float [ %5582, %5575 ], [ %8127, %__nv_exp2f.exit1417 ] + %8423 = phi float [ %5583, %5575 ], [ %8128, %__nv_exp2f.exit1417 ] + %8424 = phi float [ %5584, %5575 ], [ %8129, %__nv_exp2f.exit1417 ] + %8425 = phi float [ %5585, %5575 ], [ %8130, %__nv_exp2f.exit1417 ] + %8426 = phi float [ %5586, %5575 ], [ %8131, %__nv_exp2f.exit1417 ] + %8427 = phi float [ %5587, %5575 ], [ %8132, %__nv_exp2f.exit1417 ] + %8428 = phi float [ %5588, %5575 ], [ %8133, %__nv_exp2f.exit1417 ] + %8429 = phi float [ %5589, %5575 ], [ %8134, %__nv_exp2f.exit1417 ] + %8430 = phi float [ %5590, %5575 ], [ %8135, %__nv_exp2f.exit1417 ] + %8431 = phi float [ %5591, %5575 ], [ %8136, %__nv_exp2f.exit1417 ] + %8432 = phi float [ %5592, %5575 ], [ %8137, %__nv_exp2f.exit1417 ] + %8433 = phi float [ %5593, %5575 ], [ %8138, %__nv_exp2f.exit1417 ] + %8434 = phi float [ %5594, %5575 ], [ %8139, %__nv_exp2f.exit1417 ] + %8435 = phi float [ %5595, %5575 ], [ %8140, %__nv_exp2f.exit1417 ] + %8436 = phi float [ %5596, %5575 ], [ %8141, %__nv_exp2f.exit1417 ] + %8437 = phi float [ %5597, %5575 ], [ %8142, %__nv_exp2f.exit1417 ] + %8438 = phi float [ %5598, %5575 ], [ %8143, %__nv_exp2f.exit1417 ] + %8439 = phi float [ %5599, %5575 ], [ %8144, %__nv_exp2f.exit1417 ] + %8440 = phi float [ %5600, %5575 ], [ %8145, %__nv_exp2f.exit1417 ] + %8441 = phi float [ %5601, %5575 ], [ %8146, %__nv_exp2f.exit1417 ] + %8442 = phi float [ %5602, %5575 ], [ %8147, %__nv_exp2f.exit1417 ] + %8443 = phi float [ %5603, %5575 ], [ %8148, %__nv_exp2f.exit1417 ] + %8444 = phi float [ %5604, %5575 ], [ %8149, %__nv_exp2f.exit1417 ] + %8445 = phi float [ %5605, %5575 ], [ %8150, %__nv_exp2f.exit1417 ] + %8446 = phi float [ %5606, %5575 ], [ %8151, %__nv_exp2f.exit1417 ] + %8447 = phi float [ %5607, %5575 ], [ %8152, %__nv_exp2f.exit1417 ] + %8448 = phi float [ %5608, %5575 ], [ %8153, %__nv_exp2f.exit1417 ] + %8449 = phi float [ %5609, %5575 ], [ %8154, %__nv_exp2f.exit1417 ] + %8450 = phi float [ %5610, %5575 ], [ %8155, %__nv_exp2f.exit1417 ] + %8451 = phi float [ %5611, %5575 ], [ %8156, %__nv_exp2f.exit1417 ] + %8452 = phi float [ %5612, %5575 ], [ %8157, %__nv_exp2f.exit1417 ] + %8453 = phi float [ %5613, %5575 ], [ %8158, %__nv_exp2f.exit1417 ] + %8454 = phi float [ %5614, %5575 ], [ %8159, %__nv_exp2f.exit1417 ] + %8455 = phi float [ %5615, %5575 ], [ %8160, %__nv_exp2f.exit1417 ] + %8456 = phi float [ %5616, %5575 ], [ %8161, %__nv_exp2f.exit1417 ] + %8457 = phi float [ %5617, %5575 ], [ %8162, %__nv_exp2f.exit1417 ] + %8458 = phi float [ %5618, %5575 ], [ %8163, %__nv_exp2f.exit1417 ] + %8459 = phi float [ %5619, %5575 ], [ %8164, %__nv_exp2f.exit1417 ] + %8460 = phi float [ %5620, %5575 ], [ %8165, %__nv_exp2f.exit1417 ] + %8461 = phi float [ %5621, %5575 ], [ %8166, %__nv_exp2f.exit1417 ] + %8462 = phi float [ %5622, %5575 ], [ %8167, %__nv_exp2f.exit1417 ] + %8463 = phi float [ %5623, %5575 ], [ %8168, %__nv_exp2f.exit1417 ] + %8464 = phi float [ %5624, %5575 ], [ %8169, %__nv_exp2f.exit1417 ] + %8465 = phi float [ %5625, %5575 ], [ %8170, %__nv_exp2f.exit1417 ] + %8466 = phi float [ %5626, %5575 ], [ %8171, %__nv_exp2f.exit1417 ] + %8467 = phi float [ %5627, %5575 ], [ %8172, %__nv_exp2f.exit1417 ] + %8468 = phi float [ %5628, %5575 ], [ %8173, %__nv_exp2f.exit1417 ] + %8469 = phi float [ %5629, %5575 ], [ %8174, %__nv_exp2f.exit1417 ] + %8470 = phi float [ %5630, %5575 ], [ %8175, %__nv_exp2f.exit1417 ] + %8471 = phi float [ %5631, %5575 ], [ %8176, %__nv_exp2f.exit1417 ] + %8472 = phi float [ %5632, %5575 ], [ %8177, %__nv_exp2f.exit1417 ] + %8473 = phi float [ %5633, %5575 ], [ %8178, %__nv_exp2f.exit1417 ] + %8474 = phi float [ %5634, %5575 ], [ %8179, %__nv_exp2f.exit1417 ] + %8475 = phi float [ %5635, %5575 ], [ %8180, %__nv_exp2f.exit1417 ] + %8476 = phi float [ %5636, %5575 ], [ %8181, %__nv_exp2f.exit1417 ] + %8477 = phi float [ %5637, %5575 ], [ %8182, %__nv_exp2f.exit1417 ] + %8478 = phi float [ %5638, %5575 ], [ %8183, %__nv_exp2f.exit1417 ] + %8479 = phi float [ %5639, %5575 ], [ %8184, %__nv_exp2f.exit1417 ] + %8480 = phi float [ %5640, %5575 ], [ %7265, %__nv_exp2f.exit1417 ] + %8481 = phi float [ %5641, %5575 ], [ %7266, %__nv_exp2f.exit1417 ] + %8482 = phi float [ %5642, %5575 ], [ %7267, %__nv_exp2f.exit1417 ] + %8483 = phi float [ %5643, %5575 ], [ %7268, %__nv_exp2f.exit1417 ] + %8484 = phi float [ %5644, %5575 ], [ %7269, %__nv_exp2f.exit1417 ] + %8485 = phi float [ %5645, %5575 ], [ %7270, %__nv_exp2f.exit1417 ] + %8486 = phi float [ %5646, %5575 ], [ %7271, %__nv_exp2f.exit1417 ] + %8487 = phi float [ %5647, %5575 ], [ %7272, %__nv_exp2f.exit1417 ] + %8488 = phi float [ %5648, %5575 ], [ %7273, %__nv_exp2f.exit1417 ] + %8489 = phi float [ %5649, %5575 ], [ %7274, %__nv_exp2f.exit1417 ] + %8490 = phi float [ %5650, %5575 ], [ %7275, %__nv_exp2f.exit1417 ] + %8491 = phi float [ %5651, %5575 ], [ %7276, %__nv_exp2f.exit1417 ] + %8492 = phi float [ %5652, %5575 ], [ %7277, %__nv_exp2f.exit1417 ] + %8493 = phi float [ %5653, %5575 ], [ %7278, %__nv_exp2f.exit1417 ] + %8494 = phi float [ %5654, %5575 ], [ %7279, %__nv_exp2f.exit1417 ] + %8495 = phi float [ %5655, %5575 ], [ %7280, %__nv_exp2f.exit1417 ] + %8496 = phi float [ %5656, %5575 ], [ %7281, %__nv_exp2f.exit1417 ] + %8497 = phi float [ %5657, %5575 ], [ %7282, %__nv_exp2f.exit1417 ] + %8498 = phi float [ %5658, %5575 ], [ %7283, %__nv_exp2f.exit1417 ] + %8499 = phi float [ %5659, %5575 ], [ %7284, %__nv_exp2f.exit1417 ] + %8500 = phi float [ %5660, %5575 ], [ %7285, %__nv_exp2f.exit1417 ] + %8501 = phi float [ %5661, %5575 ], [ %7286, %__nv_exp2f.exit1417 ] + %8502 = phi float [ %5662, %5575 ], [ %7287, %__nv_exp2f.exit1417 ] + %8503 = phi float [ %5663, %5575 ], [ %7288, %__nv_exp2f.exit1417 ] + %8504 = phi float [ %5664, %5575 ], [ %7289, %__nv_exp2f.exit1417 ] + %8505 = phi float [ %5665, %5575 ], [ %7290, %__nv_exp2f.exit1417 ] + %8506 = phi float [ %5666, %5575 ], [ %7291, %__nv_exp2f.exit1417 ] + %8507 = phi float [ %5667, %5575 ], [ %7292, %__nv_exp2f.exit1417 ] + %8508 = phi float [ %5668, %5575 ], [ %7293, %__nv_exp2f.exit1417 ] + %8509 = phi float [ %5669, %5575 ], [ %7294, %__nv_exp2f.exit1417 ] + %8510 = phi float [ %5670, %5575 ], [ %7295, %__nv_exp2f.exit1417 ] + %8511 = phi float [ %5671, %5575 ], [ %7296, %__nv_exp2f.exit1417 ] + %8512 = phi float [ %5672, %5575 ], [ %7297, %__nv_exp2f.exit1417 ] + %8513 = phi float [ %5673, %5575 ], [ %7298, %__nv_exp2f.exit1417 ] + %8514 = phi float [ %5674, %5575 ], [ %7299, %__nv_exp2f.exit1417 ] + %8515 = phi float [ %5675, %5575 ], [ %7300, %__nv_exp2f.exit1417 ] + %8516 = phi float [ %5676, %5575 ], [ %7301, %__nv_exp2f.exit1417 ] + %8517 = phi float [ %5677, %5575 ], [ %7302, %__nv_exp2f.exit1417 ] + %8518 = phi float [ %5678, %5575 ], [ %7303, %__nv_exp2f.exit1417 ] + %8519 = phi float [ %5679, %5575 ], [ %7304, %__nv_exp2f.exit1417 ] + %8520 = phi float [ %5680, %5575 ], [ %7305, %__nv_exp2f.exit1417 ] + %8521 = phi float [ %5681, %5575 ], [ %7306, %__nv_exp2f.exit1417 ] + %8522 = phi float [ %5682, %5575 ], [ %7307, %__nv_exp2f.exit1417 ] + %8523 = phi float [ %5683, %5575 ], [ %7308, %__nv_exp2f.exit1417 ] + %8524 = phi float [ %5684, %5575 ], [ %7309, %__nv_exp2f.exit1417 ] + %8525 = phi float [ %5685, %5575 ], [ %7310, %__nv_exp2f.exit1417 ] + %8526 = phi float [ %5686, %5575 ], [ %7311, %__nv_exp2f.exit1417 ] + %8527 = phi float [ %5687, %5575 ], [ %7312, %__nv_exp2f.exit1417 ] + %8528 = phi float [ %5688, %5575 ], [ %7313, %__nv_exp2f.exit1417 ] + %8529 = phi float [ %5689, %5575 ], [ %7314, %__nv_exp2f.exit1417 ] + %8530 = phi float [ %5690, %5575 ], [ %7315, %__nv_exp2f.exit1417 ] + %8531 = phi float [ %5691, %5575 ], [ %7316, %__nv_exp2f.exit1417 ] + %8532 = phi float [ %5692, %5575 ], [ %7317, %__nv_exp2f.exit1417 ] + %8533 = phi float [ %5693, %5575 ], [ %7318, %__nv_exp2f.exit1417 ] + %8534 = phi float [ %5694, %5575 ], [ %7319, %__nv_exp2f.exit1417 ] + %8535 = phi float [ %5695, %5575 ], [ %7320, %__nv_exp2f.exit1417 ] + %8536 = phi float [ %5696, %5575 ], [ %7321, %__nv_exp2f.exit1417 ] + %8537 = phi float [ %5697, %5575 ], [ %7322, %__nv_exp2f.exit1417 ] + %8538 = phi float [ %5698, %5575 ], [ %7323, %__nv_exp2f.exit1417 ] + %8539 = phi float [ %5699, %5575 ], [ %7324, %__nv_exp2f.exit1417 ] + %8540 = phi float [ %5700, %5575 ], [ %7325, %__nv_exp2f.exit1417 ] + %8541 = phi float [ %5701, %5575 ], [ %7326, %__nv_exp2f.exit1417 ] + %8542 = phi float [ %5702, %5575 ], [ %7327, %__nv_exp2f.exit1417 ] + %8543 = phi float [ %5703, %5575 ], [ %7328, %__nv_exp2f.exit1417 ] + %8544 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %8480, float %8481, float %8482, float %8483, float %8484, float %8485, float %8486, float %8487, float %8488, float %8489, float %8490, float %8491, float %8492, float %8493, float %8494, float %8495, float %8496, float %8497, float %8498, float %8499, float %8500, float %8501, float %8502, float %8503, float %8504, float %8505, float %8506, float %8507, float %8508, float %8509, float %8510, float %8511, float %8512, float %8513, float %8514, float %8515, float %8516, float %8517, float %8518, float %8519, float %8520, float %8521, float %8522, float %8523, float %8524, float %8525, float %8526, float %8527, float %8528, float %8529, float %8530, float %8531, float %8532, float %8533, float %8534, float %8535, float %8536, float %8537, float %8538, float %8539, float %8540, float %8541, float %8542, float %8543, float %8416, float %8417, float %8418, float %8419, float %8420, float %8421, float %8422, float %8423, float %8424, float %8425, float %8426, float %8427, float %8428, float %8429, float %8430, float %8431, float %8432, float %8433, float %8434, float %8435, float %8436, float %8437, float %8438, float %8439, float %8440, float %8441, float %8442, float %8443, float %8444, float %8445, float %8446, float %8447, float %8448, float %8449, float %8450, float %8451, float %8452, float %8453, float %8454, float %8455, float %8456, float %8457, float %8458, float %8459, float %8460, float %8461, float %8462, float %8463, float %8464, float %8465, float %8466, float %8467, float %8468, float %8469, float %8470, float %8471, float %8472, float %8473, float %8474, float %8475, float %8476, float %8477, float %8478, float %8479) #3, !dbg !274 + %8545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 0, !dbg !274 + %8546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 1, !dbg !274 + %8547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 2, !dbg !274 + %8548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 3, !dbg !274 + %8549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 4, !dbg !274 + %8550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 5, !dbg !274 + %8551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 6, !dbg !274 + %8552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 7, !dbg !274 + %8553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 8, !dbg !274 + %8554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 9, !dbg !274 + %8555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 10, !dbg !274 + %8556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 11, !dbg !274 + %8557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 12, !dbg !274 + %8558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 13, !dbg !274 + %8559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 14, !dbg !274 + %8560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 15, !dbg !274 + %8561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 16, !dbg !274 + %8562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 17, !dbg !274 + %8563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 18, !dbg !274 + %8564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 19, !dbg !274 + %8565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 20, !dbg !274 + %8566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 21, !dbg !274 + %8567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 22, !dbg !274 + %8568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 23, !dbg !274 + %8569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 24, !dbg !274 + %8570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 25, !dbg !274 + %8571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 26, !dbg !274 + %8572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 27, !dbg !274 + %8573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 28, !dbg !274 + %8574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 29, !dbg !274 + %8575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 30, !dbg !274 + %8576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 31, !dbg !274 + %8577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 32, !dbg !274 + %8578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 33, !dbg !274 + %8579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 34, !dbg !274 + %8580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 35, !dbg !274 + %8581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 36, !dbg !274 + %8582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 37, !dbg !274 + %8583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 38, !dbg !274 + %8584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 39, !dbg !274 + %8585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 40, !dbg !274 + %8586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 41, !dbg !274 + %8587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 42, !dbg !274 + %8588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 43, !dbg !274 + %8589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 44, !dbg !274 + %8590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 45, !dbg !274 + %8591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 46, !dbg !274 + %8592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 47, !dbg !274 + %8593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 48, !dbg !274 + %8594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 49, !dbg !274 + %8595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 50, !dbg !274 + %8596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 51, !dbg !274 + %8597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 52, !dbg !274 + %8598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 53, !dbg !274 + %8599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 54, !dbg !274 + %8600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 55, !dbg !274 + %8601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 56, !dbg !274 + %8602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 57, !dbg !274 + %8603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 58, !dbg !274 + %8604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 59, !dbg !274 + %8605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 60, !dbg !274 + %8606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 61, !dbg !274 + %8607 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 62, !dbg !274 + %8608 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 63, !dbg !274 + %8609 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 64, !dbg !274 + %8610 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 65, !dbg !274 + %8611 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 66, !dbg !274 + %8612 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 67, !dbg !274 + %8613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 68, !dbg !274 + %8614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 69, !dbg !274 + %8615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 70, !dbg !274 + %8616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 71, !dbg !274 + %8617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 72, !dbg !274 + %8618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 73, !dbg !274 + %8619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 74, !dbg !274 + %8620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 75, !dbg !274 + %8621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 76, !dbg !274 + %8622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 77, !dbg !274 + %8623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 78, !dbg !274 + %8624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 79, !dbg !274 + %8625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 80, !dbg !274 + %8626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 81, !dbg !274 + %8627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 82, !dbg !274 + %8628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 83, !dbg !274 + %8629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 84, !dbg !274 + %8630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 85, !dbg !274 + %8631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 86, !dbg !274 + %8632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 87, !dbg !274 + %8633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 88, !dbg !274 + %8634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 89, !dbg !274 + %8635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 90, !dbg !274 + %8636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 91, !dbg !274 + %8637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 92, !dbg !274 + %8638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 93, !dbg !274 + %8639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 94, !dbg !274 + %8640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 95, !dbg !274 + %8641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 96, !dbg !274 + %8642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 97, !dbg !274 + %8643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 98, !dbg !274 + %8644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 99, !dbg !274 + %8645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 100, !dbg !274 + %8646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 101, !dbg !274 + %8647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 102, !dbg !274 + %8648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 103, !dbg !274 + %8649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 104, !dbg !274 + %8650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 105, !dbg !274 + %8651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 106, !dbg !274 + %8652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 107, !dbg !274 + %8653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 108, !dbg !274 + %8654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 109, !dbg !274 + %8655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 110, !dbg !274 + %8656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 111, !dbg !274 + %8657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 112, !dbg !274 + %8658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 113, !dbg !274 + %8659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 114, !dbg !274 + %8660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 115, !dbg !274 + %8661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 116, !dbg !274 + %8662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 117, !dbg !274 + %8663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 118, !dbg !274 + %8664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 119, !dbg !274 + %8665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 120, !dbg !274 + %8666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 121, !dbg !274 + %8667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 122, !dbg !274 + %8668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 123, !dbg !274 + %8669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 124, !dbg !274 + %8670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 125, !dbg !274 + %8671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 126, !dbg !274 + %8672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8544, 127, !dbg !274 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !274 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !274 + %8673 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5386, !dbg !320 + %8674 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5387, !dbg !320 + %8675 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5388, !dbg !320 + %8676 = getelementptr bfloat, ptr addrspace(1) %5716, i64 %5389, !dbg !320 + %8677 = getelementptr bfloat, ptr addrspace(1) %8673, i64 %4768, !dbg !321 + %8678 = getelementptr bfloat, ptr addrspace(1) %8674, i64 %4768, !dbg !321 + %8679 = getelementptr bfloat, ptr addrspace(1) %8675, i64 %4768, !dbg !321 + %8680 = getelementptr bfloat, ptr addrspace(1) %8676, i64 %4768, !dbg !321 + %8681 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5390, !dbg !322 + %8682 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5391, !dbg !322 + %8683 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5392, !dbg !322 + %8684 = getelementptr bfloat, ptr addrspace(1) %5717, i64 %5393, !dbg !322 + %8685 = getelementptr bfloat, ptr addrspace(1) %8681, i64 %4768, !dbg !323 + %8686 = getelementptr bfloat, ptr addrspace(1) %8682, i64 %4768, !dbg !323 + %8687 = getelementptr bfloat, ptr addrspace(1) %8683, i64 %4768, !dbg !323 + %8688 = getelementptr bfloat, ptr addrspace(1) %8684, i64 %4768, !dbg !323 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5102, ptr addrspace(1) %8677, i32 %5403) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5105, ptr addrspace(1) %8678, i32 %5404) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5108, ptr addrspace(1) %8679, i32 %5405) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5111, ptr addrspace(1) %8680, i32 %5406) #3, !dbg !324 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !324 + %8689 = getelementptr float, ptr addrspace(1) %5718, i64 %5423, !dbg !325 + %8690 = getelementptr float, ptr addrspace(1) %5718, i64 %5424, !dbg !325 + %8691 = getelementptr float, ptr addrspace(1) %5718, i64 %5425, !dbg !325 + %8692 = getelementptr float, ptr addrspace(1) %5718, i64 %5426, !dbg !325 + %8693 = getelementptr float, ptr addrspace(1) %5718, i64 %5427, !dbg !325 + %8694 = getelementptr float, ptr addrspace(1) %5718, i64 %5428, !dbg !325 + %8695 = getelementptr float, ptr addrspace(1) %5718, i64 %5429, !dbg !325 + %8696 = getelementptr float, ptr addrspace(1) %5718, i64 %5430, !dbg !325 + %8697 = getelementptr float, ptr addrspace(1) %5718, i64 %5431, !dbg !325 + %8698 = getelementptr float, ptr addrspace(1) %5718, i64 %5432, !dbg !325 + %8699 = getelementptr float, ptr addrspace(1) %5718, i64 %5433, !dbg !325 + %8700 = getelementptr float, ptr addrspace(1) %5718, i64 %5434, !dbg !325 + %8701 = getelementptr float, ptr addrspace(1) %5718, i64 %5435, !dbg !325 + %8702 = getelementptr float, ptr addrspace(1) %5718, i64 %5436, !dbg !325 + %8703 = getelementptr float, ptr addrspace(1) %5718, i64 %5437, !dbg !325 + %8704 = getelementptr float, ptr addrspace(1) %5718, i64 %5438, !dbg !325 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5180, ptr addrspace(1) %8689, i32 %5455, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5183, ptr addrspace(1) %8690, i32 %5456, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5186, ptr addrspace(1) %8691, i32 %5457, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5189, ptr addrspace(1) %8692, i32 %5458, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5192, ptr addrspace(1) %8693, i32 %5459, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5195, ptr addrspace(1) %8694, i32 %5460, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5198, ptr addrspace(1) %8695, i32 %5461, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5201, ptr addrspace(1) %8696, i32 %5462, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5204, ptr addrspace(1) %8697, i32 %5463, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5207, ptr addrspace(1) %8698, i32 %5464, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5210, ptr addrspace(1) %8699, i32 %5465, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5213, ptr addrspace(1) %8700, i32 %5466, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5216, ptr addrspace(1) %8701, i32 %5467, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5219, ptr addrspace(1) %8702, i32 %5468, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5222, ptr addrspace(1) %8703, i32 %5469, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5225, ptr addrspace(1) %8704, i32 %5470, i1 %5178) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5227, ptr addrspace(1) %8685, i32 %5403) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5228, ptr addrspace(1) %8686, i32 %5404) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5229, ptr addrspace(1) %8687, i32 %5405) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5230, ptr addrspace(1) %8688, i32 %5406) #3, !dbg !327 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !327 + %8705 = getelementptr float, ptr addrspace(1) %5719, i64 %5423, !dbg !328 + %8706 = getelementptr float, ptr addrspace(1) %5719, i64 %5424, !dbg !328 + %8707 = getelementptr float, ptr addrspace(1) %5719, i64 %5425, !dbg !328 + %8708 = getelementptr float, ptr addrspace(1) %5719, i64 %5426, !dbg !328 + %8709 = getelementptr float, ptr addrspace(1) %5719, i64 %5427, !dbg !328 + %8710 = getelementptr float, ptr addrspace(1) %5719, i64 %5428, !dbg !328 + %8711 = getelementptr float, ptr addrspace(1) %5719, i64 %5429, !dbg !328 + %8712 = getelementptr float, ptr addrspace(1) %5719, i64 %5430, !dbg !328 + %8713 = getelementptr float, ptr addrspace(1) %5719, i64 %5431, !dbg !328 + %8714 = getelementptr float, ptr addrspace(1) %5719, i64 %5432, !dbg !328 + %8715 = getelementptr float, ptr addrspace(1) %5719, i64 %5433, !dbg !328 + %8716 = getelementptr float, ptr addrspace(1) %5719, i64 %5434, !dbg !328 + %8717 = getelementptr float, ptr addrspace(1) %5719, i64 %5435, !dbg !328 + %8718 = getelementptr float, ptr addrspace(1) %5719, i64 %5436, !dbg !328 + %8719 = getelementptr float, ptr addrspace(1) %5719, i64 %5437, !dbg !328 + %8720 = getelementptr float, ptr addrspace(1) %5719, i64 %5438, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5231, ptr addrspace(1) %8705, i32 %5455, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5232, ptr addrspace(1) %8706, i32 %5456, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5233, ptr addrspace(1) %8707, i32 %5457, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5234, ptr addrspace(1) %8708, i32 %5458, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5235, ptr addrspace(1) %8709, i32 %5459, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5236, ptr addrspace(1) %8710, i32 %5460, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5237, ptr addrspace(1) %8711, i32 %5461, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5238, ptr addrspace(1) %8712, i32 %5462, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5239, ptr addrspace(1) %8713, i32 %5463, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5240, ptr addrspace(1) %8714, i32 %5464, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5241, ptr addrspace(1) %8715, i32 %5465, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5242, ptr addrspace(1) %8716, i32 %5466, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5243, ptr addrspace(1) %8717, i32 %5467, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5244, ptr addrspace(1) %8718, i32 %5468, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5245, ptr addrspace(1) %8719, i32 %5469, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5246, ptr addrspace(1) %8720, i32 %5470, i1 %5178) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + %8721 = getelementptr i8, ptr addrspace(1) %8677, i64 524288, !dbg !330 + %8722 = getelementptr i8, ptr addrspace(1) %8678, i64 524288, !dbg !330 + %8723 = getelementptr i8, ptr addrspace(1) %8679, i64 524288, !dbg !330 + %8724 = getelementptr i8, ptr addrspace(1) %8680, i64 524288, !dbg !330 + %8725 = getelementptr i8, ptr addrspace(1) %8685, i64 16384, !dbg !331 + %8726 = getelementptr i8, ptr addrspace(1) %8686, i64 16384, !dbg !331 + %8727 = getelementptr i8, ptr addrspace(1) %8687, i64 16384, !dbg !331 + %8728 = getelementptr i8, ptr addrspace(1) %8688, i64 16384, !dbg !331 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5276, ptr addrspace(1) %8721, i32 %5500) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5278, ptr addrspace(1) %8722, i32 %5501) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5280, ptr addrspace(1) %8723, i32 %5502) #3, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5282, ptr addrspace(1) %8724, i32 %5503) #3, !dbg !324 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !324 + %8729 = getelementptr float, ptr addrspace(1) %5718, i64 %5520, !dbg !325 + %8730 = getelementptr float, ptr addrspace(1) %5718, i64 %5521, !dbg !325 + %8731 = getelementptr float, ptr addrspace(1) %5718, i64 %5522, !dbg !325 + %8732 = getelementptr float, ptr addrspace(1) %5718, i64 %5523, !dbg !325 + %8733 = getelementptr float, ptr addrspace(1) %5718, i64 %5524, !dbg !325 + %8734 = getelementptr float, ptr addrspace(1) %5718, i64 %5525, !dbg !325 + %8735 = getelementptr float, ptr addrspace(1) %5718, i64 %5526, !dbg !325 + %8736 = getelementptr float, ptr addrspace(1) %5718, i64 %5527, !dbg !325 + %8737 = getelementptr float, ptr addrspace(1) %5718, i64 %5528, !dbg !325 + %8738 = getelementptr float, ptr addrspace(1) %5718, i64 %5529, !dbg !325 + %8739 = getelementptr float, ptr addrspace(1) %5718, i64 %5530, !dbg !325 + %8740 = getelementptr float, ptr addrspace(1) %5718, i64 %5531, !dbg !325 + %8741 = getelementptr float, ptr addrspace(1) %5718, i64 %5532, !dbg !325 + %8742 = getelementptr float, ptr addrspace(1) %5718, i64 %5533, !dbg !325 + %8743 = getelementptr float, ptr addrspace(1) %5718, i64 %5534, !dbg !325 + %8744 = getelementptr float, ptr addrspace(1) %5718, i64 %5535, !dbg !325 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5332, ptr addrspace(1) %8729, i32 %5552, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5334, ptr addrspace(1) %8730, i32 %5553, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5336, ptr addrspace(1) %8731, i32 %5554, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5338, ptr addrspace(1) %8732, i32 %5555, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5340, ptr addrspace(1) %8733, i32 %5556, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5342, ptr addrspace(1) %8734, i32 %5557, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5344, ptr addrspace(1) %8735, i32 %5558, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5346, ptr addrspace(1) %8736, i32 %5559, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5348, ptr addrspace(1) %8737, i32 %5560, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5350, ptr addrspace(1) %8738, i32 %5561, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5352, ptr addrspace(1) %8739, i32 %5562, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5354, ptr addrspace(1) %8740, i32 %5563, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5356, ptr addrspace(1) %8741, i32 %5564, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5358, ptr addrspace(1) %8742, i32 %5565, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5360, ptr addrspace(1) %8743, i32 %5566, i1 %5178) #3, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5362, ptr addrspace(1) %8744, i32 %5567, i1 %5178) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %5364, ptr addrspace(1) %8725, i32 %5500) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5365, ptr addrspace(1) %8726, i32 %5501) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5366, ptr addrspace(1) %8727, i32 %5502) #3, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %5367, ptr addrspace(1) %8728, i32 %5503) #3, !dbg !327 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !327 + %8745 = getelementptr float, ptr addrspace(1) %5719, i64 %5520, !dbg !328 + %8746 = getelementptr float, ptr addrspace(1) %5719, i64 %5521, !dbg !328 + %8747 = getelementptr float, ptr addrspace(1) %5719, i64 %5522, !dbg !328 + %8748 = getelementptr float, ptr addrspace(1) %5719, i64 %5523, !dbg !328 + %8749 = getelementptr float, ptr addrspace(1) %5719, i64 %5524, !dbg !328 + %8750 = getelementptr float, ptr addrspace(1) %5719, i64 %5525, !dbg !328 + %8751 = getelementptr float, ptr addrspace(1) %5719, i64 %5526, !dbg !328 + %8752 = getelementptr float, ptr addrspace(1) %5719, i64 %5527, !dbg !328 + %8753 = getelementptr float, ptr addrspace(1) %5719, i64 %5528, !dbg !328 + %8754 = getelementptr float, ptr addrspace(1) %5719, i64 %5529, !dbg !328 + %8755 = getelementptr float, ptr addrspace(1) %5719, i64 %5530, !dbg !328 + %8756 = getelementptr float, ptr addrspace(1) %5719, i64 %5531, !dbg !328 + %8757 = getelementptr float, ptr addrspace(1) %5719, i64 %5532, !dbg !328 + %8758 = getelementptr float, ptr addrspace(1) %5719, i64 %5533, !dbg !328 + %8759 = getelementptr float, ptr addrspace(1) %5719, i64 %5534, !dbg !328 + %8760 = getelementptr float, ptr addrspace(1) %5719, i64 %5535, !dbg !328 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %5368, ptr addrspace(1) %8745, i32 %5552, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5369, ptr addrspace(1) %8746, i32 %5553, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5370, ptr addrspace(1) %8747, i32 %5554, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5371, ptr addrspace(1) %8748, i32 %5555, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5372, ptr addrspace(1) %8749, i32 %5556, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5373, ptr addrspace(1) %8750, i32 %5557, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5374, ptr addrspace(1) %8751, i32 %5558, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5375, ptr addrspace(1) %8752, i32 %5559, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5376, ptr addrspace(1) %8753, i32 %5560, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5377, ptr addrspace(1) %8754, i32 %5561, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5378, ptr addrspace(1) %8755, i32 %5562, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5379, ptr addrspace(1) %8756, i32 %5563, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5380, ptr addrspace(1) %8757, i32 %5564, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5381, ptr addrspace(1) %8758, i32 %5565, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5382, ptr addrspace(1) %8759, i32 %5566, i1 %5178) #3, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %5383, ptr addrspace(1) %8760, i32 %5567, i1 %5178) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + br i1 %5394, label %.lr.ph1873, label %._crit_edge1874, !dbg !332 + +.lr.ph1873: ; preds = %._crit_edge1701, %__nv_exp2f.exit1321 + %.pn605.pn1871 = phi i32 [ %.pn6051855, %__nv_exp2f.exit1321 ], [ %5068, %._crit_edge1701 ] + %.pn607.pn1870 = phi i32 [ %.pn6071854, %__nv_exp2f.exit1321 ], [ %5066, %._crit_edge1701 ] + %.pn609.pn1869 = phi i32 [ %.pn6091853, %__nv_exp2f.exit1321 ], [ %5064, %._crit_edge1701 ] + %.pn611.pn1868 = phi i32 [ %.pn6111852, %__nv_exp2f.exit1321 ], [ %5062, %._crit_edge1701 ] + %.pn613.pn1867 = phi i32 [ %.pn6131851, %__nv_exp2f.exit1321 ], [ %5060, %._crit_edge1701 ] + %.pn615.pn1866 = phi i32 [ %.pn6151850, %__nv_exp2f.exit1321 ], [ %5058, %._crit_edge1701 ] + %.pn617.pn1865 = phi i32 [ %.pn6171849, %__nv_exp2f.exit1321 ], [ %5056, %._crit_edge1701 ] + %.pn619.pn1864 = phi i32 [ %.pn6191848, %__nv_exp2f.exit1321 ], [ %5054, %._crit_edge1701 ] + %.pn621.pn1863 = phi i32 [ %.pn6211847, %__nv_exp2f.exit1321 ], [ %5052, %._crit_edge1701 ] + %.pn623.pn1862 = phi i32 [ %.pn6231846, %__nv_exp2f.exit1321 ], [ %5050, %._crit_edge1701 ] + %.pn625.pn1861 = phi i32 [ %.pn6251845, %__nv_exp2f.exit1321 ], [ %5048, %._crit_edge1701 ] + %.pn627.pn1860 = phi i32 [ %.pn6271844, %__nv_exp2f.exit1321 ], [ %5046, %._crit_edge1701 ] + %.pn629.pn1859 = phi i32 [ %.pn6291843, %__nv_exp2f.exit1321 ], [ %5044, %._crit_edge1701 ] + %.pn631.pn1858 = phi i32 [ %.pn6311842, %__nv_exp2f.exit1321 ], [ %5043, %._crit_edge1701 ] + %.pn633.pn1857 = phi i32 [ %.pn6331841, %__nv_exp2f.exit1321 ], [ %5042, %._crit_edge1701 ] + %.pn635.pn1856 = phi i32 [ %.pn6351840, %__nv_exp2f.exit1321 ], [ %5041, %._crit_edge1701 ] + %8761 = phi i32 [ %8778, %__nv_exp2f.exit1321 ], [ -1, %._crit_edge1701 ] + %8762 = phi i32 [ %10800, %__nv_exp2f.exit1321 ], [ 1, %._crit_edge1701 ] + %8763 = phi i32 [ %8781, %__nv_exp2f.exit1321 ], [ -1, %._crit_edge1701 ] + %8764 = phi i32 [ %10803, %__nv_exp2f.exit1321 ], [ 1, %._crit_edge1701 ] + %.pn6051855 = phi i32 [ %10789, %__nv_exp2f.exit1321 ], [ %5487, %._crit_edge1701 ] + %.pn6071854 = phi i32 [ %10788, %__nv_exp2f.exit1321 ], [ %5486, %._crit_edge1701 ] + %.pn6091853 = phi i32 [ %10787, %__nv_exp2f.exit1321 ], [ %5485, %._crit_edge1701 ] + %.pn6111852 = phi i32 [ %10786, %__nv_exp2f.exit1321 ], [ %5484, %._crit_edge1701 ] + %.pn6131851 = phi i32 [ %10785, %__nv_exp2f.exit1321 ], [ %5483, %._crit_edge1701 ] + %.pn6151850 = phi i32 [ %10784, %__nv_exp2f.exit1321 ], [ %5482, %._crit_edge1701 ] + %.pn6171849 = phi i32 [ %10783, %__nv_exp2f.exit1321 ], [ %5481, %._crit_edge1701 ] + %.pn6191848 = phi i32 [ %10782, %__nv_exp2f.exit1321 ], [ %5480, %._crit_edge1701 ] + %.pn6211847 = phi i32 [ %10781, %__nv_exp2f.exit1321 ], [ %5479, %._crit_edge1701 ] + %.pn6231846 = phi i32 [ %10780, %__nv_exp2f.exit1321 ], [ %5478, %._crit_edge1701 ] + %.pn6251845 = phi i32 [ %10779, %__nv_exp2f.exit1321 ], [ %5477, %._crit_edge1701 ] + %.pn6271844 = phi i32 [ %10778, %__nv_exp2f.exit1321 ], [ %5476, %._crit_edge1701 ] + %.pn6291843 = phi i32 [ %10777, %__nv_exp2f.exit1321 ], [ %5475, %._crit_edge1701 ] + %.pn6311842 = phi i32 [ %10776, %__nv_exp2f.exit1321 ], [ %5474, %._crit_edge1701 ] + %.pn6331841 = phi i32 [ %10775, %__nv_exp2f.exit1321 ], [ %5473, %._crit_edge1701 ] + %.pn6351840 = phi i32 [ %10774, %__nv_exp2f.exit1321 ], [ %5472, %._crit_edge1701 ] + %8765 = phi i32 [ %10794, %__nv_exp2f.exit1321 ], [ %5488, %._crit_edge1701 ] + %8766 = phi i32 [ %10795, %__nv_exp2f.exit1321 ], [ %5489, %._crit_edge1701 ] + %8767 = phi i32 [ %10796, %__nv_exp2f.exit1321 ], [ %5490, %._crit_edge1701 ] + %8768 = phi i32 [ %10797, %__nv_exp2f.exit1321 ], [ %5491, %._crit_edge1701 ] + %.pn5551839 = phi ptr addrspace(1) [ %10773, %__nv_exp2f.exit1321 ], [ %8728, %._crit_edge1701 ] + %.pn5711838 = phi ptr addrspace(1) [ %10772, %__nv_exp2f.exit1321 ], [ %8727, %._crit_edge1701 ] + %.pn5871837 = phi ptr addrspace(1) [ %10771, %__nv_exp2f.exit1321 ], [ %8726, %._crit_edge1701 ] + %.pn6031836 = phi ptr addrspace(1) [ %10770, %__nv_exp2f.exit1321 ], [ %8725, %._crit_edge1701 ] + %8769 = phi i32 [ %10790, %__nv_exp2f.exit1321 ], [ %5488, %._crit_edge1701 ] + %8770 = phi i32 [ %10791, %__nv_exp2f.exit1321 ], [ %5489, %._crit_edge1701 ] + %8771 = phi i32 [ %10792, %__nv_exp2f.exit1321 ], [ %5490, %._crit_edge1701 ] + %8772 = phi i32 [ %10793, %__nv_exp2f.exit1321 ], [ %5491, %._crit_edge1701 ] + %.pn4911835 = phi ptr addrspace(1) [ %10767, %__nv_exp2f.exit1321 ], [ %8724, %._crit_edge1701 ] + %.pn5071834 = phi ptr addrspace(1) [ %10766, %__nv_exp2f.exit1321 ], [ %8723, %._crit_edge1701 ] + %.pn5231833 = phi ptr addrspace(1) [ %10765, %__nv_exp2f.exit1321 ], [ %8722, %._crit_edge1701 ] + %.pn5391832 = phi ptr addrspace(1) [ %10764, %__nv_exp2f.exit1321 ], [ %8721, %._crit_edge1701 ] + %.pn3491831 = phi float [ %9885, %__nv_exp2f.exit1321 ], [ %8608, %._crit_edge1701 ] + %.pn3511830 = phi float [ %9884, %__nv_exp2f.exit1321 ], [ %8607, %._crit_edge1701 ] + %.pn3531829 = phi float [ %9883, %__nv_exp2f.exit1321 ], [ %8606, %._crit_edge1701 ] + %.pn3551828 = phi float [ %9882, %__nv_exp2f.exit1321 ], [ %8605, %._crit_edge1701 ] + %.pn3571827 = phi float [ %9881, %__nv_exp2f.exit1321 ], [ %8604, %._crit_edge1701 ] + %.pn3591826 = phi float [ %9880, %__nv_exp2f.exit1321 ], [ %8603, %._crit_edge1701 ] + %.pn3611825 = phi float [ %9879, %__nv_exp2f.exit1321 ], [ %8602, %._crit_edge1701 ] + %.pn3631824 = phi float [ %9878, %__nv_exp2f.exit1321 ], [ %8601, %._crit_edge1701 ] + %.pn3651823 = phi float [ %9877, %__nv_exp2f.exit1321 ], [ %8600, %._crit_edge1701 ] + %.pn3671822 = phi float [ %9876, %__nv_exp2f.exit1321 ], [ %8599, %._crit_edge1701 ] + %.pn3691821 = phi float [ %9875, %__nv_exp2f.exit1321 ], [ %8598, %._crit_edge1701 ] + %.pn3711820 = phi float [ %9874, %__nv_exp2f.exit1321 ], [ %8597, %._crit_edge1701 ] + %.pn3731819 = phi float [ %9873, %__nv_exp2f.exit1321 ], [ %8596, %._crit_edge1701 ] + %.pn3751818 = phi float [ %9872, %__nv_exp2f.exit1321 ], [ %8595, %._crit_edge1701 ] + %.pn3771817 = phi float [ %9871, %__nv_exp2f.exit1321 ], [ %8594, %._crit_edge1701 ] + %.pn3791816 = phi float [ %9870, %__nv_exp2f.exit1321 ], [ %8593, %._crit_edge1701 ] + %.pn3811815 = phi float [ %9869, %__nv_exp2f.exit1321 ], [ %8592, %._crit_edge1701 ] + %.pn3831814 = phi float [ %9868, %__nv_exp2f.exit1321 ], [ %8591, %._crit_edge1701 ] + %.pn3851813 = phi float [ %9867, %__nv_exp2f.exit1321 ], [ %8590, %._crit_edge1701 ] + %.pn3871812 = phi float [ %9866, %__nv_exp2f.exit1321 ], [ %8589, %._crit_edge1701 ] + %.pn3891811 = phi float [ %9865, %__nv_exp2f.exit1321 ], [ %8588, %._crit_edge1701 ] + %.pn3911810 = phi float [ %9864, %__nv_exp2f.exit1321 ], [ %8587, %._crit_edge1701 ] + %.pn3931809 = phi float [ %9863, %__nv_exp2f.exit1321 ], [ %8586, %._crit_edge1701 ] + %.pn3951808 = phi float [ %9862, %__nv_exp2f.exit1321 ], [ %8585, %._crit_edge1701 ] + %.pn3971807 = phi float [ %9861, %__nv_exp2f.exit1321 ], [ %8584, %._crit_edge1701 ] + %.pn3991806 = phi float [ %9860, %__nv_exp2f.exit1321 ], [ %8583, %._crit_edge1701 ] + %.pn4011805 = phi float [ %9859, %__nv_exp2f.exit1321 ], [ %8582, %._crit_edge1701 ] + %.pn4031804 = phi float [ %9858, %__nv_exp2f.exit1321 ], [ %8581, %._crit_edge1701 ] + %.pn4051803 = phi float [ %9857, %__nv_exp2f.exit1321 ], [ %8580, %._crit_edge1701 ] + %.pn4071802 = phi float [ %9856, %__nv_exp2f.exit1321 ], [ %8579, %._crit_edge1701 ] + %.pn4091801 = phi float [ %9855, %__nv_exp2f.exit1321 ], [ %8578, %._crit_edge1701 ] + %.pn4111800 = phi float [ %9854, %__nv_exp2f.exit1321 ], [ %8577, %._crit_edge1701 ] + %.pn4131799 = phi float [ %9853, %__nv_exp2f.exit1321 ], [ %8576, %._crit_edge1701 ] + %.pn4151798 = phi float [ %9852, %__nv_exp2f.exit1321 ], [ %8575, %._crit_edge1701 ] + %.pn4171797 = phi float [ %9851, %__nv_exp2f.exit1321 ], [ %8574, %._crit_edge1701 ] + %.pn4191796 = phi float [ %9850, %__nv_exp2f.exit1321 ], [ %8573, %._crit_edge1701 ] + %.pn4211795 = phi float [ %9849, %__nv_exp2f.exit1321 ], [ %8572, %._crit_edge1701 ] + %.pn4231794 = phi float [ %9848, %__nv_exp2f.exit1321 ], [ %8571, %._crit_edge1701 ] + %.pn4251793 = phi float [ %9847, %__nv_exp2f.exit1321 ], [ %8570, %._crit_edge1701 ] + %.pn4271792 = phi float [ %9846, %__nv_exp2f.exit1321 ], [ %8569, %._crit_edge1701 ] + %.pn4291791 = phi float [ %9845, %__nv_exp2f.exit1321 ], [ %8568, %._crit_edge1701 ] + %.pn4311790 = phi float [ %9844, %__nv_exp2f.exit1321 ], [ %8567, %._crit_edge1701 ] + %.pn4331789 = phi float [ %9843, %__nv_exp2f.exit1321 ], [ %8566, %._crit_edge1701 ] + %.pn4351788 = phi float [ %9842, %__nv_exp2f.exit1321 ], [ %8565, %._crit_edge1701 ] + %.pn4371787 = phi float [ %9841, %__nv_exp2f.exit1321 ], [ %8564, %._crit_edge1701 ] + %.pn4391786 = phi float [ %9840, %__nv_exp2f.exit1321 ], [ %8563, %._crit_edge1701 ] + %.pn4411785 = phi float [ %9839, %__nv_exp2f.exit1321 ], [ %8562, %._crit_edge1701 ] + %.pn4431784 = phi float [ %9838, %__nv_exp2f.exit1321 ], [ %8561, %._crit_edge1701 ] + %.pn4451783 = phi float [ %9837, %__nv_exp2f.exit1321 ], [ %8560, %._crit_edge1701 ] + %.pn4471782 = phi float [ %9836, %__nv_exp2f.exit1321 ], [ %8559, %._crit_edge1701 ] + %.pn4491781 = phi float [ %9835, %__nv_exp2f.exit1321 ], [ %8558, %._crit_edge1701 ] + %.pn4511780 = phi float [ %9834, %__nv_exp2f.exit1321 ], [ %8557, %._crit_edge1701 ] + %.pn4531779 = phi float [ %9833, %__nv_exp2f.exit1321 ], [ %8556, %._crit_edge1701 ] + %.pn4551778 = phi float [ %9832, %__nv_exp2f.exit1321 ], [ %8555, %._crit_edge1701 ] + %.pn4571777 = phi float [ %9831, %__nv_exp2f.exit1321 ], [ %8554, %._crit_edge1701 ] + %.pn4591776 = phi float [ %9830, %__nv_exp2f.exit1321 ], [ %8553, %._crit_edge1701 ] + %.pn4611775 = phi float [ %9829, %__nv_exp2f.exit1321 ], [ %8552, %._crit_edge1701 ] + %.pn4631774 = phi float [ %9828, %__nv_exp2f.exit1321 ], [ %8551, %._crit_edge1701 ] + %.pn4651773 = phi float [ %9827, %__nv_exp2f.exit1321 ], [ %8550, %._crit_edge1701 ] + %.pn4671772 = phi float [ %9826, %__nv_exp2f.exit1321 ], [ %8549, %._crit_edge1701 ] + %.pn4691771 = phi float [ %9825, %__nv_exp2f.exit1321 ], [ %8548, %._crit_edge1701 ] + %.pn4711770 = phi float [ %9824, %__nv_exp2f.exit1321 ], [ %8547, %._crit_edge1701 ] + %.pn4731769 = phi float [ %9823, %__nv_exp2f.exit1321 ], [ %8546, %._crit_edge1701 ] + %.pn4751768 = phi float [ %9822, %__nv_exp2f.exit1321 ], [ %8545, %._crit_edge1701 ] + %.pn2211767 = phi float [ %10741, %__nv_exp2f.exit1321 ], [ %8672, %._crit_edge1701 ] + %.pn2231766 = phi float [ %10740, %__nv_exp2f.exit1321 ], [ %8671, %._crit_edge1701 ] + %.pn2251765 = phi float [ %10739, %__nv_exp2f.exit1321 ], [ %8670, %._crit_edge1701 ] + %.pn2271764 = phi float [ %10738, %__nv_exp2f.exit1321 ], [ %8669, %._crit_edge1701 ] + %.pn2291763 = phi float [ %10737, %__nv_exp2f.exit1321 ], [ %8668, %._crit_edge1701 ] + %.pn2311762 = phi float [ %10736, %__nv_exp2f.exit1321 ], [ %8667, %._crit_edge1701 ] + %.pn2331761 = phi float [ %10735, %__nv_exp2f.exit1321 ], [ %8666, %._crit_edge1701 ] + %.pn2351760 = phi float [ %10734, %__nv_exp2f.exit1321 ], [ %8665, %._crit_edge1701 ] + %.pn2371759 = phi float [ %10733, %__nv_exp2f.exit1321 ], [ %8664, %._crit_edge1701 ] + %.pn2391758 = phi float [ %10732, %__nv_exp2f.exit1321 ], [ %8663, %._crit_edge1701 ] + %.pn2411757 = phi float [ %10731, %__nv_exp2f.exit1321 ], [ %8662, %._crit_edge1701 ] + %.pn2431756 = phi float [ %10730, %__nv_exp2f.exit1321 ], [ %8661, %._crit_edge1701 ] + %.pn2451755 = phi float [ %10729, %__nv_exp2f.exit1321 ], [ %8660, %._crit_edge1701 ] + %.pn2471754 = phi float [ %10728, %__nv_exp2f.exit1321 ], [ %8659, %._crit_edge1701 ] + %.pn2491753 = phi float [ %10727, %__nv_exp2f.exit1321 ], [ %8658, %._crit_edge1701 ] + %.pn2511752 = phi float [ %10726, %__nv_exp2f.exit1321 ], [ %8657, %._crit_edge1701 ] + %.pn2531751 = phi float [ %10725, %__nv_exp2f.exit1321 ], [ %8656, %._crit_edge1701 ] + %.pn2551750 = phi float [ %10724, %__nv_exp2f.exit1321 ], [ %8655, %._crit_edge1701 ] + %.pn2571749 = phi float [ %10723, %__nv_exp2f.exit1321 ], [ %8654, %._crit_edge1701 ] + %.pn2591748 = phi float [ %10722, %__nv_exp2f.exit1321 ], [ %8653, %._crit_edge1701 ] + %.pn2611747 = phi float [ %10721, %__nv_exp2f.exit1321 ], [ %8652, %._crit_edge1701 ] + %.pn2631746 = phi float [ %10720, %__nv_exp2f.exit1321 ], [ %8651, %._crit_edge1701 ] + %.pn2651745 = phi float [ %10719, %__nv_exp2f.exit1321 ], [ %8650, %._crit_edge1701 ] + %.pn2671744 = phi float [ %10718, %__nv_exp2f.exit1321 ], [ %8649, %._crit_edge1701 ] + %.pn2691743 = phi float [ %10717, %__nv_exp2f.exit1321 ], [ %8648, %._crit_edge1701 ] + %.pn2711742 = phi float [ %10716, %__nv_exp2f.exit1321 ], [ %8647, %._crit_edge1701 ] + %.pn2731741 = phi float [ %10715, %__nv_exp2f.exit1321 ], [ %8646, %._crit_edge1701 ] + %.pn2751740 = phi float [ %10714, %__nv_exp2f.exit1321 ], [ %8645, %._crit_edge1701 ] + %.pn2771739 = phi float [ %10713, %__nv_exp2f.exit1321 ], [ %8644, %._crit_edge1701 ] + %.pn2791738 = phi float [ %10712, %__nv_exp2f.exit1321 ], [ %8643, %._crit_edge1701 ] + %.pn2811737 = phi float [ %10711, %__nv_exp2f.exit1321 ], [ %8642, %._crit_edge1701 ] + %.pn2831736 = phi float [ %10710, %__nv_exp2f.exit1321 ], [ %8641, %._crit_edge1701 ] + %.pn2851735 = phi float [ %10709, %__nv_exp2f.exit1321 ], [ %8640, %._crit_edge1701 ] + %.pn2871734 = phi float [ %10708, %__nv_exp2f.exit1321 ], [ %8639, %._crit_edge1701 ] + %.pn2891733 = phi float [ %10707, %__nv_exp2f.exit1321 ], [ %8638, %._crit_edge1701 ] + %.pn2911732 = phi float [ %10706, %__nv_exp2f.exit1321 ], [ %8637, %._crit_edge1701 ] + %.pn2931731 = phi float [ %10705, %__nv_exp2f.exit1321 ], [ %8636, %._crit_edge1701 ] + %.pn2951730 = phi float [ %10704, %__nv_exp2f.exit1321 ], [ %8635, %._crit_edge1701 ] + %.pn2971729 = phi float [ %10703, %__nv_exp2f.exit1321 ], [ %8634, %._crit_edge1701 ] + %.pn2991728 = phi float [ %10702, %__nv_exp2f.exit1321 ], [ %8633, %._crit_edge1701 ] + %.pn3011727 = phi float [ %10701, %__nv_exp2f.exit1321 ], [ %8632, %._crit_edge1701 ] + %.pn3031726 = phi float [ %10700, %__nv_exp2f.exit1321 ], [ %8631, %._crit_edge1701 ] + %.pn3051725 = phi float [ %10699, %__nv_exp2f.exit1321 ], [ %8630, %._crit_edge1701 ] + %.pn3071724 = phi float [ %10698, %__nv_exp2f.exit1321 ], [ %8629, %._crit_edge1701 ] + %.pn3091723 = phi float [ %10697, %__nv_exp2f.exit1321 ], [ %8628, %._crit_edge1701 ] + %.pn3111722 = phi float [ %10696, %__nv_exp2f.exit1321 ], [ %8627, %._crit_edge1701 ] + %.pn3131721 = phi float [ %10695, %__nv_exp2f.exit1321 ], [ %8626, %._crit_edge1701 ] + %.pn3151720 = phi float [ %10694, %__nv_exp2f.exit1321 ], [ %8625, %._crit_edge1701 ] + %.pn3171719 = phi float [ %10693, %__nv_exp2f.exit1321 ], [ %8624, %._crit_edge1701 ] + %.pn3191718 = phi float [ %10692, %__nv_exp2f.exit1321 ], [ %8623, %._crit_edge1701 ] + %.pn3211717 = phi float [ %10691, %__nv_exp2f.exit1321 ], [ %8622, %._crit_edge1701 ] + %.pn3231716 = phi float [ %10690, %__nv_exp2f.exit1321 ], [ %8621, %._crit_edge1701 ] + %.pn3251715 = phi float [ %10689, %__nv_exp2f.exit1321 ], [ %8620, %._crit_edge1701 ] + %.pn3271714 = phi float [ %10688, %__nv_exp2f.exit1321 ], [ %8619, %._crit_edge1701 ] + %.pn3291713 = phi float [ %10687, %__nv_exp2f.exit1321 ], [ %8618, %._crit_edge1701 ] + %.pn3311712 = phi float [ %10686, %__nv_exp2f.exit1321 ], [ %8617, %._crit_edge1701 ] + %.pn3331711 = phi float [ %10685, %__nv_exp2f.exit1321 ], [ %8616, %._crit_edge1701 ] + %.pn3351710 = phi float [ %10684, %__nv_exp2f.exit1321 ], [ %8615, %._crit_edge1701 ] + %.pn3371709 = phi float [ %10683, %__nv_exp2f.exit1321 ], [ %8614, %._crit_edge1701 ] + %.pn3391708 = phi float [ %10682, %__nv_exp2f.exit1321 ], [ %8613, %._crit_edge1701 ] + %.pn3411707 = phi float [ %10681, %__nv_exp2f.exit1321 ], [ %8612, %._crit_edge1701 ] + %.pn3431706 = phi float [ %10680, %__nv_exp2f.exit1321 ], [ %8611, %._crit_edge1701 ] + %.pn3451705 = phi float [ %10679, %__nv_exp2f.exit1321 ], [ %8610, %._crit_edge1701 ] + %.pn3471704 = phi float [ %10678, %__nv_exp2f.exit1321 ], [ %8609, %._crit_edge1701 ] + %8773 = phi i32 [ %10742, %__nv_exp2f.exit1321 ], [ 0, %._crit_edge1701 ] + %8774 = icmp slt i32 %8773, %5568, !dbg !332 + %8775 = icmp slt i32 %8773, %5569, !dbg !332 + %8776 = add i32 %8761, 1, !dbg !332 + %8777 = icmp sgt i32 %8776, 1, !dbg !332 + %8778 = select i1 %8777, i32 0, i32 %8776, !dbg !332 + %8779 = add i32 %8763, 1, !dbg !332 + %8780 = icmp sgt i32 %8779, 2, !dbg !332 + %8781 = select i1 %8780, i32 0, i32 %8779, !dbg !332 + %8782 = icmp slt i32 %.pn635.pn1856, %17, !dbg !333 + %8783 = icmp slt i32 %.pn633.pn1857, %17, !dbg !333 + %8784 = icmp slt i32 %.pn631.pn1858, %17, !dbg !333 + %8785 = icmp slt i32 %.pn629.pn1859, %17, !dbg !333 + %8786 = icmp slt i32 %.pn627.pn1860, %17, !dbg !333 + %8787 = icmp slt i32 %.pn625.pn1861, %17, !dbg !333 + %8788 = icmp slt i32 %.pn623.pn1862, %17, !dbg !333 + %8789 = icmp slt i32 %.pn621.pn1863, %17, !dbg !333 + %8790 = icmp slt i32 %.pn619.pn1864, %17, !dbg !333 + %8791 = icmp slt i32 %.pn617.pn1865, %17, !dbg !333 + %8792 = icmp slt i32 %.pn615.pn1866, %17, !dbg !333 + %8793 = icmp slt i32 %.pn613.pn1867, %17, !dbg !333 + %8794 = icmp slt i32 %.pn611.pn1868, %17, !dbg !333 + %8795 = icmp slt i32 %.pn609.pn1869, %17, !dbg !333 + %8796 = icmp slt i32 %.pn607.pn1870, %17, !dbg !333 + %8797 = icmp slt i32 %.pn605.pn1871, %17, !dbg !333 + tail call void @llvm.nvvm.cp.async.wait.group(i32 4), !dbg !324 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !324 + %8798 = shl i32 %8781, 13, !dbg !324 + %8799 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %8798, !dbg !324 + %8800 = shl i32 %8778, 6, !dbg !326 + %8801 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %8800, !dbg !326 + %8802 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5179, !dbg !326 + %8803 = load float, ptr addrspace(3) %8802, align 8, !dbg !326 + %8804 = getelementptr inbounds nuw i8, ptr addrspace(3) %8802, i32 4, !dbg !326 + %8805 = load float, ptr addrspace(3) %8804, align 4, !dbg !326 + %8806 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5185, !dbg !326 + %8807 = load float, ptr addrspace(3) %8806, align 8, !dbg !326 + %8808 = getelementptr inbounds nuw i8, ptr addrspace(3) %8806, i32 4, !dbg !326 + %8809 = load float, ptr addrspace(3) %8808, align 4, !dbg !326 + %8810 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5191, !dbg !326 + %8811 = load float, ptr addrspace(3) %8810, align 8, !dbg !326 + %8812 = getelementptr inbounds nuw i8, ptr addrspace(3) %8810, i32 4, !dbg !326 + %8813 = load float, ptr addrspace(3) %8812, align 4, !dbg !326 + %8814 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5197, !dbg !326 + %8815 = load float, ptr addrspace(3) %8814, align 8, !dbg !326 + %8816 = getelementptr inbounds nuw i8, ptr addrspace(3) %8814, i32 4, !dbg !326 + %8817 = load float, ptr addrspace(3) %8816, align 4, !dbg !326 + %8818 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5203, !dbg !326 + %8819 = load float, ptr addrspace(3) %8818, align 8, !dbg !326 + %8820 = getelementptr inbounds nuw i8, ptr addrspace(3) %8818, i32 4, !dbg !326 + %8821 = load float, ptr addrspace(3) %8820, align 4, !dbg !326 + %8822 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5209, !dbg !326 + %8823 = load float, ptr addrspace(3) %8822, align 8, !dbg !326 + %8824 = getelementptr inbounds nuw i8, ptr addrspace(3) %8822, i32 4, !dbg !326 + %8825 = load float, ptr addrspace(3) %8824, align 4, !dbg !326 + %8826 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5215, !dbg !326 + %8827 = load float, ptr addrspace(3) %8826, align 8, !dbg !326 + %8828 = getelementptr inbounds nuw i8, ptr addrspace(3) %8826, i32 4, !dbg !326 + %8829 = load float, ptr addrspace(3) %8828, align 4, !dbg !326 + %8830 = getelementptr inbounds nuw i8, ptr addrspace(3) %8801, i32 %5221, !dbg !326 + %8831 = load float, ptr addrspace(3) %8830, align 8, !dbg !326 + %8832 = getelementptr inbounds nuw i8, ptr addrspace(3) %8830, i32 4, !dbg !326 + %8833 = load float, ptr addrspace(3) %8832, align 4, !dbg !326 + %8834 = fcmp oeq float %8803, 0xFFF0000000000000, !dbg !334 + %8835 = fcmp oeq float %8805, 0xFFF0000000000000, !dbg !334 + %8836 = fcmp oeq float %8807, 0xFFF0000000000000, !dbg !334 + %8837 = fcmp oeq float %8809, 0xFFF0000000000000, !dbg !334 + %8838 = fcmp oeq float %8811, 0xFFF0000000000000, !dbg !334 + %8839 = fcmp oeq float %8813, 0xFFF0000000000000, !dbg !334 + %8840 = fcmp oeq float %8815, 0xFFF0000000000000, !dbg !334 + %8841 = fcmp oeq float %8817, 0xFFF0000000000000, !dbg !334 + %8842 = fcmp oeq float %8819, 0xFFF0000000000000, !dbg !334 + %8843 = fcmp oeq float %8821, 0xFFF0000000000000, !dbg !334 + %8844 = fcmp oeq float %8823, 0xFFF0000000000000, !dbg !334 + %8845 = fcmp oeq float %8825, 0xFFF0000000000000, !dbg !334 + %8846 = fcmp oeq float %8827, 0xFFF0000000000000, !dbg !334 + %8847 = fcmp oeq float %8829, 0xFFF0000000000000, !dbg !334 + %8848 = fcmp oeq float %8831, 0xFFF0000000000000, !dbg !334 + %8849 = fcmp oeq float %8833, 0xFFF0000000000000, !dbg !334 + %8850 = select i1 %8834, float 0.000000e+00, float %8803, !dbg !335 + %8851 = select i1 %8835, float 0.000000e+00, float %8805, !dbg !335 + %8852 = select i1 %8836, float 0.000000e+00, float %8807, !dbg !335 + %8853 = select i1 %8837, float 0.000000e+00, float %8809, !dbg !335 + %8854 = select i1 %8838, float 0.000000e+00, float %8811, !dbg !335 + %8855 = select i1 %8839, float 0.000000e+00, float %8813, !dbg !335 + %8856 = select i1 %8840, float 0.000000e+00, float %8815, !dbg !335 + %8857 = select i1 %8841, float 0.000000e+00, float %8817, !dbg !335 + %8858 = select i1 %8842, float 0.000000e+00, float %8819, !dbg !335 + %8859 = select i1 %8843, float 0.000000e+00, float %8821, !dbg !335 + %8860 = select i1 %8844, float 0.000000e+00, float %8823, !dbg !335 + %8861 = select i1 %8845, float 0.000000e+00, float %8825, !dbg !335 + %8862 = select i1 %8846, float 0.000000e+00, float %8827, !dbg !335 + %8863 = select i1 %8847, float 0.000000e+00, float %8829, !dbg !335 + %8864 = select i1 %8848, float 0.000000e+00, float %8831, !dbg !335 + %8865 = select i1 %8849, float 0.000000e+00, float %8833, !dbg !335 + %8866 = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 %45, i32 0, i32 31), !dbg !336 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !336 + %8867 = shl i32 %8866, 11, !dbg !336 + %8868 = and i32 %8867, 8192, !dbg !336 + %8869 = add i32 %8868, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %8870 = lshr exact i32 %8869, 4, !dbg !336 + %8871 = and i32 %8870, 16383, !dbg !336 + %8872 = zext nneg i32 %8871 to i64, !dbg !336 + %8873 = or disjoint i64 %8872, 4611686293372403712, !dbg !336 + %8874 = ptrtoint ptr addrspace(3) %8799 to i32, !dbg !336 + %8875 = lshr exact i32 %8874, 4, !dbg !336 + %8876 = and i32 %8875, 16383, !dbg !336 + %8877 = zext nneg i32 %8876 to i64, !dbg !336 + %8878 = or disjoint i64 %8877, 4611686293338849280, !dbg !336 + %8879 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %8873, i64 %8878) #3, !dbg !336 + %8880 = or disjoint i32 %8868, 32, !dbg !336 + %8881 = add i32 %8880, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %8882 = lshr exact i32 %8881, 4, !dbg !336 + %8883 = and i32 %8882, 16383, !dbg !336 + %8884 = zext nneg i32 %8883 to i64, !dbg !336 + %8885 = or disjoint i64 %8884, 4611686293372403712, !dbg !336 + %8886 = add i32 %8874, 32, !dbg !336 + %8887 = lshr exact i32 %8886, 4, !dbg !336 + %8888 = and i32 %8887, 16383, !dbg !336 + %8889 = zext nneg i32 %8888 to i64, !dbg !336 + %8890 = or disjoint i64 %8889, 4611686293338849280, !dbg !336 + %8891 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 0, !dbg !336 + %8892 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 1, !dbg !336 + %8893 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 2, !dbg !336 + %8894 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 3, !dbg !336 + %8895 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 4, !dbg !336 + %8896 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 5, !dbg !336 + %8897 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 6, !dbg !336 + %8898 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 7, !dbg !336 + %8899 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 8, !dbg !336 + %8900 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 9, !dbg !336 + %8901 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 10, !dbg !336 + %8902 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 11, !dbg !336 + %8903 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 12, !dbg !336 + %8904 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 13, !dbg !336 + %8905 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 14, !dbg !336 + %8906 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 15, !dbg !336 + %8907 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 16, !dbg !336 + %8908 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 17, !dbg !336 + %8909 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 18, !dbg !336 + %8910 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 19, !dbg !336 + %8911 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 20, !dbg !336 + %8912 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 21, !dbg !336 + %8913 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 22, !dbg !336 + %8914 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 23, !dbg !336 + %8915 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 24, !dbg !336 + %8916 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 25, !dbg !336 + %8917 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 26, !dbg !336 + %8918 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 27, !dbg !336 + %8919 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 28, !dbg !336 + %8920 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 29, !dbg !336 + %8921 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 30, !dbg !336 + %8922 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8879, 31, !dbg !336 + %8923 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8891, float %8892, float %8893, float %8894, float %8895, float %8896, float %8897, float %8898, float %8899, float %8900, float %8901, float %8902, float %8903, float %8904, float %8905, float %8906, float %8907, float %8908, float %8909, float %8910, float %8911, float %8912, float %8913, float %8914, float %8915, float %8916, float %8917, float %8918, float %8919, float %8920, float %8921, float %8922, i64 %8885, i64 %8890, i1 true) #3, !dbg !336 + %8924 = or disjoint i32 %8868, 64, !dbg !336 + %8925 = add i32 %8924, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %8926 = lshr exact i32 %8925, 4, !dbg !336 + %8927 = and i32 %8926, 16383, !dbg !336 + %8928 = zext nneg i32 %8927 to i64, !dbg !336 + %8929 = or disjoint i64 %8928, 4611686293372403712, !dbg !336 + %8930 = add i32 %8874, 64, !dbg !336 + %8931 = lshr exact i32 %8930, 4, !dbg !336 + %8932 = and i32 %8931, 16383, !dbg !336 + %8933 = zext nneg i32 %8932 to i64, !dbg !336 + %8934 = or disjoint i64 %8933, 4611686293338849280, !dbg !336 + %8935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 0, !dbg !336 + %8936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 1, !dbg !336 + %8937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 2, !dbg !336 + %8938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 3, !dbg !336 + %8939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 4, !dbg !336 + %8940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 5, !dbg !336 + %8941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 6, !dbg !336 + %8942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 7, !dbg !336 + %8943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 8, !dbg !336 + %8944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 9, !dbg !336 + %8945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 10, !dbg !336 + %8946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 11, !dbg !336 + %8947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 12, !dbg !336 + %8948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 13, !dbg !336 + %8949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 14, !dbg !336 + %8950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 15, !dbg !336 + %8951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 16, !dbg !336 + %8952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 17, !dbg !336 + %8953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 18, !dbg !336 + %8954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 19, !dbg !336 + %8955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 20, !dbg !336 + %8956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 21, !dbg !336 + %8957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 22, !dbg !336 + %8958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 23, !dbg !336 + %8959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 24, !dbg !336 + %8960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 25, !dbg !336 + %8961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 26, !dbg !336 + %8962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 27, !dbg !336 + %8963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 28, !dbg !336 + %8964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 29, !dbg !336 + %8965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 30, !dbg !336 + %8966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8923, 31, !dbg !336 + %8967 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8935, float %8936, float %8937, float %8938, float %8939, float %8940, float %8941, float %8942, float %8943, float %8944, float %8945, float %8946, float %8947, float %8948, float %8949, float %8950, float %8951, float %8952, float %8953, float %8954, float %8955, float %8956, float %8957, float %8958, float %8959, float %8960, float %8961, float %8962, float %8963, float %8964, float %8965, float %8966, i64 %8929, i64 %8934, i1 true) #3, !dbg !336 + %8968 = or disjoint i32 %8868, 96, !dbg !336 + %8969 = add i32 %8968, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %8970 = lshr exact i32 %8969, 4, !dbg !336 + %8971 = and i32 %8970, 16383, !dbg !336 + %8972 = zext nneg i32 %8971 to i64, !dbg !336 + %8973 = or disjoint i64 %8972, 4611686293372403712, !dbg !336 + %8974 = add i32 %8874, 96, !dbg !336 + %8975 = lshr exact i32 %8974, 4, !dbg !336 + %8976 = and i32 %8975, 16383, !dbg !336 + %8977 = zext nneg i32 %8976 to i64, !dbg !336 + %8978 = or disjoint i64 %8977, 4611686293338849280, !dbg !336 + %8979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 0, !dbg !336 + %8980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 1, !dbg !336 + %8981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 2, !dbg !336 + %8982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 3, !dbg !336 + %8983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 4, !dbg !336 + %8984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 5, !dbg !336 + %8985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 6, !dbg !336 + %8986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 7, !dbg !336 + %8987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 8, !dbg !336 + %8988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 9, !dbg !336 + %8989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 10, !dbg !336 + %8990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 11, !dbg !336 + %8991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 12, !dbg !336 + %8992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 13, !dbg !336 + %8993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 14, !dbg !336 + %8994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 15, !dbg !336 + %8995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 16, !dbg !336 + %8996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 17, !dbg !336 + %8997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 18, !dbg !336 + %8998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 19, !dbg !336 + %8999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 20, !dbg !336 + %9000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 21, !dbg !336 + %9001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 22, !dbg !336 + %9002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 23, !dbg !336 + %9003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 24, !dbg !336 + %9004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 25, !dbg !336 + %9005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 26, !dbg !336 + %9006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 27, !dbg !336 + %9007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 28, !dbg !336 + %9008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 29, !dbg !336 + %9009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 30, !dbg !336 + %9010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %8967, 31, !dbg !336 + %9011 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %8979, float %8980, float %8981, float %8982, float %8983, float %8984, float %8985, float %8986, float %8987, float %8988, float %8989, float %8990, float %8991, float %8992, float %8993, float %8994, float %8995, float %8996, float %8997, float %8998, float %8999, float %9000, float %9001, float %9002, float %9003, float %9004, float %9005, float %9006, float %9007, float %9008, float %9009, float %9010, i64 %8973, i64 %8978, i1 true) #3, !dbg !336 + %9012 = or disjoint i32 %8868, 16384, !dbg !336 + %9013 = add i32 %9012, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %9014 = lshr exact i32 %9013, 4, !dbg !336 + %9015 = and i32 %9014, 16383, !dbg !336 + %9016 = zext nneg i32 %9015 to i64, !dbg !336 + %9017 = or disjoint i64 %9016, 4611686293372403712, !dbg !336 + %9018 = add i32 %8874, 8192, !dbg !336 + %9019 = lshr exact i32 %9018, 4, !dbg !336 + %9020 = and i32 %9019, 16383, !dbg !336 + %9021 = zext nneg i32 %9020 to i64, !dbg !336 + %9022 = or disjoint i64 %9021, 4611686293338849280, !dbg !336 + %9023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 0, !dbg !336 + %9024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 1, !dbg !336 + %9025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 2, !dbg !336 + %9026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 3, !dbg !336 + %9027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 4, !dbg !336 + %9028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 5, !dbg !336 + %9029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 6, !dbg !336 + %9030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 7, !dbg !336 + %9031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 8, !dbg !336 + %9032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 9, !dbg !336 + %9033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 10, !dbg !336 + %9034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 11, !dbg !336 + %9035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 12, !dbg !336 + %9036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 13, !dbg !336 + %9037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 14, !dbg !336 + %9038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 15, !dbg !336 + %9039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 16, !dbg !336 + %9040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 17, !dbg !336 + %9041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 18, !dbg !336 + %9042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 19, !dbg !336 + %9043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 20, !dbg !336 + %9044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 21, !dbg !336 + %9045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 22, !dbg !336 + %9046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 23, !dbg !336 + %9047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 24, !dbg !336 + %9048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 25, !dbg !336 + %9049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 26, !dbg !336 + %9050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 27, !dbg !336 + %9051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 28, !dbg !336 + %9052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 29, !dbg !336 + %9053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 30, !dbg !336 + %9054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9011, 31, !dbg !336 + %9055 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9023, float %9024, float %9025, float %9026, float %9027, float %9028, float %9029, float %9030, float %9031, float %9032, float %9033, float %9034, float %9035, float %9036, float %9037, float %9038, float %9039, float %9040, float %9041, float %9042, float %9043, float %9044, float %9045, float %9046, float %9047, float %9048, float %9049, float %9050, float %9051, float %9052, float %9053, float %9054, i64 %9017, i64 %9022, i1 true) #3, !dbg !336 + %9056 = or disjoint i32 %8868, 16416, !dbg !336 + %9057 = add i32 %9056, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %9058 = lshr exact i32 %9057, 4, !dbg !336 + %9059 = and i32 %9058, 16383, !dbg !336 + %9060 = zext nneg i32 %9059 to i64, !dbg !336 + %9061 = or disjoint i64 %9060, 4611686293372403712, !dbg !336 + %9062 = add i32 %8874, 8224, !dbg !336 + %9063 = lshr exact i32 %9062, 4, !dbg !336 + %9064 = and i32 %9063, 16383, !dbg !336 + %9065 = zext nneg i32 %9064 to i64, !dbg !336 + %9066 = or disjoint i64 %9065, 4611686293338849280, !dbg !336 + %9067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 0, !dbg !336 + %9068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 1, !dbg !336 + %9069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 2, !dbg !336 + %9070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 3, !dbg !336 + %9071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 4, !dbg !336 + %9072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 5, !dbg !336 + %9073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 6, !dbg !336 + %9074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 7, !dbg !336 + %9075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 8, !dbg !336 + %9076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 9, !dbg !336 + %9077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 10, !dbg !336 + %9078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 11, !dbg !336 + %9079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 12, !dbg !336 + %9080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 13, !dbg !336 + %9081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 14, !dbg !336 + %9082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 15, !dbg !336 + %9083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 16, !dbg !336 + %9084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 17, !dbg !336 + %9085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 18, !dbg !336 + %9086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 19, !dbg !336 + %9087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 20, !dbg !336 + %9088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 21, !dbg !336 + %9089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 22, !dbg !336 + %9090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 23, !dbg !336 + %9091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 24, !dbg !336 + %9092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 25, !dbg !336 + %9093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 26, !dbg !336 + %9094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 27, !dbg !336 + %9095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 28, !dbg !336 + %9096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 29, !dbg !336 + %9097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 30, !dbg !336 + %9098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9055, 31, !dbg !336 + %9099 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9067, float %9068, float %9069, float %9070, float %9071, float %9072, float %9073, float %9074, float %9075, float %9076, float %9077, float %9078, float %9079, float %9080, float %9081, float %9082, float %9083, float %9084, float %9085, float %9086, float %9087, float %9088, float %9089, float %9090, float %9091, float %9092, float %9093, float %9094, float %9095, float %9096, float %9097, float %9098, i64 %9061, i64 %9066, i1 true) #3, !dbg !336 + %9100 = or disjoint i32 %8868, 16448, !dbg !336 + %9101 = add i32 %9100, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %9102 = lshr exact i32 %9101, 4, !dbg !336 + %9103 = and i32 %9102, 16383, !dbg !336 + %9104 = zext nneg i32 %9103 to i64, !dbg !336 + %9105 = or disjoint i64 %9104, 4611686293372403712, !dbg !336 + %9106 = add i32 %8874, 8256, !dbg !336 + %9107 = lshr exact i32 %9106, 4, !dbg !336 + %9108 = and i32 %9107, 16383, !dbg !336 + %9109 = zext nneg i32 %9108 to i64, !dbg !336 + %9110 = or disjoint i64 %9109, 4611686293338849280, !dbg !336 + %9111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 0, !dbg !336 + %9112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 1, !dbg !336 + %9113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 2, !dbg !336 + %9114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 3, !dbg !336 + %9115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 4, !dbg !336 + %9116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 5, !dbg !336 + %9117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 6, !dbg !336 + %9118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 7, !dbg !336 + %9119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 8, !dbg !336 + %9120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 9, !dbg !336 + %9121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 10, !dbg !336 + %9122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 11, !dbg !336 + %9123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 12, !dbg !336 + %9124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 13, !dbg !336 + %9125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 14, !dbg !336 + %9126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 15, !dbg !336 + %9127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 16, !dbg !336 + %9128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 17, !dbg !336 + %9129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 18, !dbg !336 + %9130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 19, !dbg !336 + %9131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 20, !dbg !336 + %9132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 21, !dbg !336 + %9133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 22, !dbg !336 + %9134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 23, !dbg !336 + %9135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 24, !dbg !336 + %9136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 25, !dbg !336 + %9137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 26, !dbg !336 + %9138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 27, !dbg !336 + %9139 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 28, !dbg !336 + %9140 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 29, !dbg !336 + %9141 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 30, !dbg !336 + %9142 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9099, 31, !dbg !336 + %9143 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9111, float %9112, float %9113, float %9114, float %9115, float %9116, float %9117, float %9118, float %9119, float %9120, float %9121, float %9122, float %9123, float %9124, float %9125, float %9126, float %9127, float %9128, float %9129, float %9130, float %9131, float %9132, float %9133, float %9134, float %9135, float %9136, float %9137, float %9138, float %9139, float %9140, float %9141, float %9142, i64 %9105, i64 %9110, i1 true) #3, !dbg !336 + %9144 = or disjoint i32 %8868, 16480, !dbg !336 + %9145 = add i32 %9144, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328) to i32), !dbg !336 + %9146 = lshr exact i32 %9145, 4, !dbg !336 + %9147 = and i32 %9146, 16383, !dbg !336 + %9148 = zext nneg i32 %9147 to i64, !dbg !336 + %9149 = or disjoint i64 %9148, 4611686293372403712, !dbg !336 + %9150 = add i32 %8874, 8288, !dbg !336 + %9151 = lshr exact i32 %9150, 4, !dbg !336 + %9152 = and i32 %9151, 16383, !dbg !336 + %9153 = zext nneg i32 %9152 to i64, !dbg !336 + %9154 = or disjoint i64 %9153, 4611686293338849280, !dbg !336 + %9155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 0, !dbg !336 + %9156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 1, !dbg !336 + %9157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 2, !dbg !336 + %9158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 3, !dbg !336 + %9159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 4, !dbg !336 + %9160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 5, !dbg !336 + %9161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 6, !dbg !336 + %9162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 7, !dbg !336 + %9163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 8, !dbg !336 + %9164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 9, !dbg !336 + %9165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 10, !dbg !336 + %9166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 11, !dbg !336 + %9167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 12, !dbg !336 + %9168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 13, !dbg !336 + %9169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 14, !dbg !336 + %9170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 15, !dbg !336 + %9171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 16, !dbg !336 + %9172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 17, !dbg !336 + %9173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 18, !dbg !336 + %9174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 19, !dbg !336 + %9175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 20, !dbg !336 + %9176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 21, !dbg !336 + %9177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 22, !dbg !336 + %9178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 23, !dbg !336 + %9179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 24, !dbg !336 + %9180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 25, !dbg !336 + %9181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 26, !dbg !336 + %9182 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 27, !dbg !336 + %9183 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 28, !dbg !336 + %9184 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 29, !dbg !336 + %9185 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 30, !dbg !336 + %9186 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9143, 31, !dbg !336 + %9187 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9155, float %9156, float %9157, float %9158, float %9159, float %9160, float %9161, float %9162, float %9163, float %9164, float %9165, float %9166, float %9167, float %9168, float %9169, float %9170, float %9171, float %9172, float %9173, float %9174, float %9175, float %9176, float %9177, float %9178, float %9179, float %9180, float %9181, float %9182, float %9183, float %9184, float %9185, float %9186, i64 %9149, i64 %9154, i1 true) #3, !dbg !336 + %9188 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 0, !dbg !336 + %9189 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 1, !dbg !336 + %9190 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 2, !dbg !336 + %9191 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 3, !dbg !336 + %9192 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 4, !dbg !336 + %9193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 5, !dbg !336 + %9194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 6, !dbg !336 + %9195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 7, !dbg !336 + %9196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 8, !dbg !336 + %9197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 9, !dbg !336 + %9198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 10, !dbg !336 + %9199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 11, !dbg !336 + %9200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 12, !dbg !336 + %9201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 13, !dbg !336 + %9202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 14, !dbg !336 + %9203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 15, !dbg !336 + %9204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 16, !dbg !336 + %9205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 17, !dbg !336 + %9206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 18, !dbg !336 + %9207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 19, !dbg !336 + %9208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 20, !dbg !336 + %9209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 21, !dbg !336 + %9210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 22, !dbg !336 + %9211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 23, !dbg !336 + %9212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 24, !dbg !336 + %9213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 25, !dbg !336 + %9214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 26, !dbg !336 + %9215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 27, !dbg !336 + %9216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 28, !dbg !336 + %9217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 29, !dbg !336 + %9218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 30, !dbg !336 + %9219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9187, 31, !dbg !336 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !336 + %9220 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %9188, float %9189, float %9190, float %9191, float %9192, float %9193, float %9194, float %9195, float %9196, float %9197, float %9198, float %9199, float %9200, float %9201, float %9202, float %9203, float %9204, float %9205, float %9206, float %9207, float %9208, float %9209, float %9210, float %9211, float %9212, float %9213, float %9214, float %9215, float %9216, float %9217, float %9218, float %9219, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 99328), i32 0, i32 0, ptr addrspace(3) %8799, i32 0, i32 0) #3, !dbg !336 + %9221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 0, !dbg !336 + %9222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 1, !dbg !336 + %9223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 2, !dbg !336 + %9224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 3, !dbg !336 + %9225 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 4, !dbg !336 + %9226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 5, !dbg !336 + %9227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 6, !dbg !336 + %9228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 7, !dbg !336 + %9229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 8, !dbg !336 + %9230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 9, !dbg !336 + %9231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 10, !dbg !336 + %9232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 11, !dbg !336 + %9233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 12, !dbg !336 + %9234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 13, !dbg !336 + %9235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 14, !dbg !336 + %9236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 15, !dbg !336 + %9237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 16, !dbg !336 + %9238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 17, !dbg !336 + %9239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 18, !dbg !336 + %9240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 19, !dbg !336 + %9241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 20, !dbg !336 + %9242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 21, !dbg !336 + %9243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 22, !dbg !336 + %9244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 23, !dbg !336 + %9245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 24, !dbg !336 + %9246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 25, !dbg !336 + %9247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 26, !dbg !336 + %9248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 27, !dbg !336 + %9249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 28, !dbg !336 + %9250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 29, !dbg !336 + %9251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 30, !dbg !336 + %9252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %9220, 31, !dbg !336 + %9253 = fmul float %9221, 0x3FB6A09E60000000, !dbg !337 + %9254 = fmul float %9222, 0x3FB6A09E60000000, !dbg !337 + %9255 = fmul float %9223, 0x3FB6A09E60000000, !dbg !337 + %9256 = fmul float %9224, 0x3FB6A09E60000000, !dbg !337 + %9257 = fmul float %9225, 0x3FB6A09E60000000, !dbg !337 + %9258 = fmul float %9226, 0x3FB6A09E60000000, !dbg !337 + %9259 = fmul float %9227, 0x3FB6A09E60000000, !dbg !337 + %9260 = fmul float %9228, 0x3FB6A09E60000000, !dbg !337 + %9261 = fmul float %9229, 0x3FB6A09E60000000, !dbg !337 + %9262 = fmul float %9230, 0x3FB6A09E60000000, !dbg !337 + %9263 = fmul float %9231, 0x3FB6A09E60000000, !dbg !337 + %9264 = fmul float %9232, 0x3FB6A09E60000000, !dbg !337 + %9265 = fmul float %9233, 0x3FB6A09E60000000, !dbg !337 + %9266 = fmul float %9234, 0x3FB6A09E60000000, !dbg !337 + %9267 = fmul float %9235, 0x3FB6A09E60000000, !dbg !337 + %9268 = fmul float %9236, 0x3FB6A09E60000000, !dbg !337 + %9269 = fmul float %9237, 0x3FB6A09E60000000, !dbg !337 + %9270 = fmul float %9238, 0x3FB6A09E60000000, !dbg !337 + %9271 = fmul float %9239, 0x3FB6A09E60000000, !dbg !337 + %9272 = fmul float %9240, 0x3FB6A09E60000000, !dbg !337 + %9273 = fmul float %9241, 0x3FB6A09E60000000, !dbg !337 + %9274 = fmul float %9242, 0x3FB6A09E60000000, !dbg !337 + %9275 = fmul float %9243, 0x3FB6A09E60000000, !dbg !337 + %9276 = fmul float %9244, 0x3FB6A09E60000000, !dbg !337 + %9277 = fmul float %9245, 0x3FB6A09E60000000, !dbg !337 + %9278 = fmul float %9246, 0x3FB6A09E60000000, !dbg !337 + %9279 = fmul float %9247, 0x3FB6A09E60000000, !dbg !337 + %9280 = fmul float %9248, 0x3FB6A09E60000000, !dbg !337 + %9281 = fmul float %9249, 0x3FB6A09E60000000, !dbg !337 + %9282 = fmul float %9250, 0x3FB6A09E60000000, !dbg !337 + %9283 = fmul float %9251, 0x3FB6A09E60000000, !dbg !337 + %9284 = fmul float %9252, 0x3FB6A09E60000000, !dbg !337 + %9285 = fmul float %9253, 0x3FF7154760000000, !dbg !338 + %9286 = select i1 %8782, float %9285, float 0xFFF0000000000000, !dbg !339 + %9287 = fmul float %9254, 0x3FF7154760000000, !dbg !338 + %9288 = select i1 %8783, float %9287, float 0xFFF0000000000000, !dbg !339 + %9289 = fmul float %9255, 0x3FF7154760000000, !dbg !338 + %9290 = select i1 %8782, float %9289, float 0xFFF0000000000000, !dbg !339 + %9291 = fmul float %9256, 0x3FF7154760000000, !dbg !338 + %9292 = select i1 %8783, float %9291, float 0xFFF0000000000000, !dbg !339 + %9293 = fmul float %9257, 0x3FF7154760000000, !dbg !338 + %9294 = select i1 %8784, float %9293, float 0xFFF0000000000000, !dbg !339 + %9295 = fmul float %9258, 0x3FF7154760000000, !dbg !338 + %9296 = select i1 %8785, float %9295, float 0xFFF0000000000000, !dbg !339 + %9297 = fmul float %9259, 0x3FF7154760000000, !dbg !338 + %9298 = select i1 %8784, float %9297, float 0xFFF0000000000000, !dbg !339 + %9299 = fmul float %9260, 0x3FF7154760000000, !dbg !338 + %9300 = select i1 %8785, float %9299, float 0xFFF0000000000000, !dbg !339 + %9301 = fmul float %9261, 0x3FF7154760000000, !dbg !338 + %9302 = select i1 %8786, float %9301, float 0xFFF0000000000000, !dbg !339 + %9303 = fmul float %9262, 0x3FF7154760000000, !dbg !338 + %9304 = select i1 %8787, float %9303, float 0xFFF0000000000000, !dbg !339 + %9305 = fmul float %9263, 0x3FF7154760000000, !dbg !338 + %9306 = select i1 %8786, float %9305, float 0xFFF0000000000000, !dbg !339 + %9307 = fmul float %9264, 0x3FF7154760000000, !dbg !338 + %9308 = select i1 %8787, float %9307, float 0xFFF0000000000000, !dbg !339 + %9309 = fmul float %9265, 0x3FF7154760000000, !dbg !338 + %9310 = select i1 %8788, float %9309, float 0xFFF0000000000000, !dbg !339 + %9311 = fmul float %9266, 0x3FF7154760000000, !dbg !338 + %9312 = select i1 %8789, float %9311, float 0xFFF0000000000000, !dbg !339 + %9313 = fmul float %9267, 0x3FF7154760000000, !dbg !338 + %9314 = select i1 %8788, float %9313, float 0xFFF0000000000000, !dbg !339 + %9315 = fmul float %9268, 0x3FF7154760000000, !dbg !338 + %9316 = select i1 %8789, float %9315, float 0xFFF0000000000000, !dbg !339 + %9317 = fmul float %9269, 0x3FF7154760000000, !dbg !338 + %9318 = select i1 %8790, float %9317, float 0xFFF0000000000000, !dbg !339 + %9319 = fmul float %9270, 0x3FF7154760000000, !dbg !338 + %9320 = select i1 %8791, float %9319, float 0xFFF0000000000000, !dbg !339 + %9321 = fmul float %9271, 0x3FF7154760000000, !dbg !338 + %9322 = select i1 %8790, float %9321, float 0xFFF0000000000000, !dbg !339 + %9323 = fmul float %9272, 0x3FF7154760000000, !dbg !338 + %9324 = select i1 %8791, float %9323, float 0xFFF0000000000000, !dbg !339 + %9325 = fmul float %9273, 0x3FF7154760000000, !dbg !338 + %9326 = select i1 %8792, float %9325, float 0xFFF0000000000000, !dbg !339 + %9327 = fmul float %9274, 0x3FF7154760000000, !dbg !338 + %9328 = select i1 %8793, float %9327, float 0xFFF0000000000000, !dbg !339 + %9329 = fmul float %9275, 0x3FF7154760000000, !dbg !338 + %9330 = select i1 %8792, float %9329, float 0xFFF0000000000000, !dbg !339 + %9331 = fmul float %9276, 0x3FF7154760000000, !dbg !338 + %9332 = select i1 %8793, float %9331, float 0xFFF0000000000000, !dbg !339 + %9333 = fmul float %9277, 0x3FF7154760000000, !dbg !338 + %9334 = select i1 %8794, float %9333, float 0xFFF0000000000000, !dbg !339 + %9335 = fmul float %9278, 0x3FF7154760000000, !dbg !338 + %9336 = select i1 %8795, float %9335, float 0xFFF0000000000000, !dbg !339 + %9337 = fmul float %9279, 0x3FF7154760000000, !dbg !338 + %9338 = select i1 %8794, float %9337, float 0xFFF0000000000000, !dbg !339 + %9339 = fmul float %9280, 0x3FF7154760000000, !dbg !338 + %9340 = select i1 %8795, float %9339, float 0xFFF0000000000000, !dbg !339 + %9341 = fmul float %9281, 0x3FF7154760000000, !dbg !338 + %9342 = select i1 %8796, float %9341, float 0xFFF0000000000000, !dbg !339 + %9343 = fmul float %9282, 0x3FF7154760000000, !dbg !338 + %9344 = select i1 %8797, float %9343, float 0xFFF0000000000000, !dbg !339 + %9345 = fmul float %9283, 0x3FF7154760000000, !dbg !338 + %9346 = select i1 %8796, float %9345, float 0xFFF0000000000000, !dbg !339 + %9347 = fmul float %9284, 0x3FF7154760000000, !dbg !338 + %9348 = select i1 %8797, float %9347, float 0xFFF0000000000000, !dbg !339 + %9349 = fsub float %9286, %8850, !dbg !340 + %9350 = fsub float %9288, %8851, !dbg !340 + %9351 = fsub float %9290, %8850, !dbg !340 + %9352 = fsub float %9292, %8851, !dbg !340 + %9353 = fsub float %9294, %8852, !dbg !340 + %9354 = fsub float %9296, %8853, !dbg !340 + %9355 = fsub float %9298, %8852, !dbg !340 + %9356 = fsub float %9300, %8853, !dbg !340 + %9357 = fsub float %9302, %8854, !dbg !340 + %9358 = fsub float %9304, %8855, !dbg !340 + %9359 = fsub float %9306, %8854, !dbg !340 + %9360 = fsub float %9308, %8855, !dbg !340 + %9361 = fsub float %9310, %8856, !dbg !340 + %9362 = fsub float %9312, %8857, !dbg !340 + %9363 = fsub float %9314, %8856, !dbg !340 + %9364 = fsub float %9316, %8857, !dbg !340 + %9365 = fsub float %9318, %8858, !dbg !340 + %9366 = fsub float %9320, %8859, !dbg !340 + %9367 = fsub float %9322, %8858, !dbg !340 + %9368 = fsub float %9324, %8859, !dbg !340 + %9369 = fsub float %9326, %8860, !dbg !340 + %9370 = fsub float %9328, %8861, !dbg !340 + %9371 = fsub float %9330, %8860, !dbg !340 + %9372 = fsub float %9332, %8861, !dbg !340 + %9373 = fsub float %9334, %8862, !dbg !340 + %9374 = fsub float %9336, %8863, !dbg !340 + %9375 = fsub float %9338, %8862, !dbg !340 + %9376 = fsub float %9340, %8863, !dbg !340 + %9377 = fsub float %9342, %8864, !dbg !340 + %9378 = fsub float %9344, %8865, !dbg !340 + %9379 = fsub float %9346, %8864, !dbg !340 + %9380 = fsub float %9348, %8865, !dbg !340 + %9381 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i = icmp eq i32 %9381, 0, !dbg !341 + br i1 %.not.i, label %9384, label %9382, !dbg !341 + +9382: ; preds = %.lr.ph1873 + %9383 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9349) #3, !dbg !341 + br label %__nv_exp2f.exit, !dbg !341 + +9384: ; preds = %.lr.ph1873 + %9385 = tail call float @llvm.nvvm.ex2.approx.f(float %9349) #3, !dbg !341 + br label %__nv_exp2f.exit, !dbg !341 + +__nv_exp2f.exit: ; preds = %9382, %9384 + %.0.i = phi float [ %9383, %9382 ], [ %9385, %9384 ], !dbg !341 + %9386 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1229 = icmp eq i32 %9386, 0, !dbg !341 + br i1 %.not.i1229, label %9389, label %9387, !dbg !341 + +9387: ; preds = %__nv_exp2f.exit + %9388 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9350) #3, !dbg !341 + br label %__nv_exp2f.exit1231, !dbg !341 + +9389: ; preds = %__nv_exp2f.exit + %9390 = tail call float @llvm.nvvm.ex2.approx.f(float %9350) #3, !dbg !341 + br label %__nv_exp2f.exit1231, !dbg !341 + +__nv_exp2f.exit1231: ; preds = %9387, %9389 + %.0.i1230 = phi float [ %9388, %9387 ], [ %9390, %9389 ], !dbg !341 + %9391 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1232 = icmp eq i32 %9391, 0, !dbg !341 + br i1 %.not.i1232, label %9394, label %9392, !dbg !341 + +9392: ; preds = %__nv_exp2f.exit1231 + %9393 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9351) #3, !dbg !341 + br label %__nv_exp2f.exit1234, !dbg !341 + +9394: ; preds = %__nv_exp2f.exit1231 + %9395 = tail call float @llvm.nvvm.ex2.approx.f(float %9351) #3, !dbg !341 + br label %__nv_exp2f.exit1234, !dbg !341 + +__nv_exp2f.exit1234: ; preds = %9392, %9394 + %.0.i1233 = phi float [ %9393, %9392 ], [ %9395, %9394 ], !dbg !341 + %9396 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1235 = icmp eq i32 %9396, 0, !dbg !341 + br i1 %.not.i1235, label %9399, label %9397, !dbg !341 + +9397: ; preds = %__nv_exp2f.exit1234 + %9398 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9352) #3, !dbg !341 + br label %__nv_exp2f.exit1237, !dbg !341 + +9399: ; preds = %__nv_exp2f.exit1234 + %9400 = tail call float @llvm.nvvm.ex2.approx.f(float %9352) #3, !dbg !341 + br label %__nv_exp2f.exit1237, !dbg !341 + +__nv_exp2f.exit1237: ; preds = %9397, %9399 + %.0.i1236 = phi float [ %9398, %9397 ], [ %9400, %9399 ], !dbg !341 + %9401 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1238 = icmp eq i32 %9401, 0, !dbg !341 + br i1 %.not.i1238, label %9404, label %9402, !dbg !341 + +9402: ; preds = %__nv_exp2f.exit1237 + %9403 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9353) #3, !dbg !341 + br label %__nv_exp2f.exit1240, !dbg !341 + +9404: ; preds = %__nv_exp2f.exit1237 + %9405 = tail call float @llvm.nvvm.ex2.approx.f(float %9353) #3, !dbg !341 + br label %__nv_exp2f.exit1240, !dbg !341 + +__nv_exp2f.exit1240: ; preds = %9402, %9404 + %.0.i1239 = phi float [ %9403, %9402 ], [ %9405, %9404 ], !dbg !341 + %9406 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1241 = icmp eq i32 %9406, 0, !dbg !341 + br i1 %.not.i1241, label %9409, label %9407, !dbg !341 + +9407: ; preds = %__nv_exp2f.exit1240 + %9408 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9354) #3, !dbg !341 + br label %__nv_exp2f.exit1243, !dbg !341 + +9409: ; preds = %__nv_exp2f.exit1240 + %9410 = tail call float @llvm.nvvm.ex2.approx.f(float %9354) #3, !dbg !341 + br label %__nv_exp2f.exit1243, !dbg !341 + +__nv_exp2f.exit1243: ; preds = %9407, %9409 + %.0.i1242 = phi float [ %9408, %9407 ], [ %9410, %9409 ], !dbg !341 + %9411 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1244 = icmp eq i32 %9411, 0, !dbg !341 + br i1 %.not.i1244, label %9414, label %9412, !dbg !341 + +9412: ; preds = %__nv_exp2f.exit1243 + %9413 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9355) #3, !dbg !341 + br label %__nv_exp2f.exit1246, !dbg !341 + +9414: ; preds = %__nv_exp2f.exit1243 + %9415 = tail call float @llvm.nvvm.ex2.approx.f(float %9355) #3, !dbg !341 + br label %__nv_exp2f.exit1246, !dbg !341 + +__nv_exp2f.exit1246: ; preds = %9412, %9414 + %.0.i1245 = phi float [ %9413, %9412 ], [ %9415, %9414 ], !dbg !341 + %9416 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1247 = icmp eq i32 %9416, 0, !dbg !341 + br i1 %.not.i1247, label %9419, label %9417, !dbg !341 + +9417: ; preds = %__nv_exp2f.exit1246 + %9418 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9356) #3, !dbg !341 + br label %__nv_exp2f.exit1249, !dbg !341 + +9419: ; preds = %__nv_exp2f.exit1246 + %9420 = tail call float @llvm.nvvm.ex2.approx.f(float %9356) #3, !dbg !341 + br label %__nv_exp2f.exit1249, !dbg !341 + +__nv_exp2f.exit1249: ; preds = %9417, %9419 + %.0.i1248 = phi float [ %9418, %9417 ], [ %9420, %9419 ], !dbg !341 + %9421 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1250 = icmp eq i32 %9421, 0, !dbg !341 + br i1 %.not.i1250, label %9424, label %9422, !dbg !341 + +9422: ; preds = %__nv_exp2f.exit1249 + %9423 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9357) #3, !dbg !341 + br label %__nv_exp2f.exit1252, !dbg !341 + +9424: ; preds = %__nv_exp2f.exit1249 + %9425 = tail call float @llvm.nvvm.ex2.approx.f(float %9357) #3, !dbg !341 + br label %__nv_exp2f.exit1252, !dbg !341 + +__nv_exp2f.exit1252: ; preds = %9422, %9424 + %.0.i1251 = phi float [ %9423, %9422 ], [ %9425, %9424 ], !dbg !341 + %9426 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1253 = icmp eq i32 %9426, 0, !dbg !341 + br i1 %.not.i1253, label %9429, label %9427, !dbg !341 + +9427: ; preds = %__nv_exp2f.exit1252 + %9428 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9358) #3, !dbg !341 + br label %__nv_exp2f.exit1255, !dbg !341 + +9429: ; preds = %__nv_exp2f.exit1252 + %9430 = tail call float @llvm.nvvm.ex2.approx.f(float %9358) #3, !dbg !341 + br label %__nv_exp2f.exit1255, !dbg !341 + +__nv_exp2f.exit1255: ; preds = %9427, %9429 + %.0.i1254 = phi float [ %9428, %9427 ], [ %9430, %9429 ], !dbg !341 + %9431 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1256 = icmp eq i32 %9431, 0, !dbg !341 + br i1 %.not.i1256, label %9434, label %9432, !dbg !341 + +9432: ; preds = %__nv_exp2f.exit1255 + %9433 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9359) #3, !dbg !341 + br label %__nv_exp2f.exit1258, !dbg !341 + +9434: ; preds = %__nv_exp2f.exit1255 + %9435 = tail call float @llvm.nvvm.ex2.approx.f(float %9359) #3, !dbg !341 + br label %__nv_exp2f.exit1258, !dbg !341 + +__nv_exp2f.exit1258: ; preds = %9432, %9434 + %.0.i1257 = phi float [ %9433, %9432 ], [ %9435, %9434 ], !dbg !341 + %9436 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1259 = icmp eq i32 %9436, 0, !dbg !341 + br i1 %.not.i1259, label %9439, label %9437, !dbg !341 + +9437: ; preds = %__nv_exp2f.exit1258 + %9438 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9360) #3, !dbg !341 + br label %__nv_exp2f.exit1261, !dbg !341 + +9439: ; preds = %__nv_exp2f.exit1258 + %9440 = tail call float @llvm.nvvm.ex2.approx.f(float %9360) #3, !dbg !341 + br label %__nv_exp2f.exit1261, !dbg !341 + +__nv_exp2f.exit1261: ; preds = %9437, %9439 + %.0.i1260 = phi float [ %9438, %9437 ], [ %9440, %9439 ], !dbg !341 + %9441 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1262 = icmp eq i32 %9441, 0, !dbg !341 + br i1 %.not.i1262, label %9444, label %9442, !dbg !341 + +9442: ; preds = %__nv_exp2f.exit1261 + %9443 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9361) #3, !dbg !341 + br label %__nv_exp2f.exit1264, !dbg !341 + +9444: ; preds = %__nv_exp2f.exit1261 + %9445 = tail call float @llvm.nvvm.ex2.approx.f(float %9361) #3, !dbg !341 + br label %__nv_exp2f.exit1264, !dbg !341 + +__nv_exp2f.exit1264: ; preds = %9442, %9444 + %.0.i1263 = phi float [ %9443, %9442 ], [ %9445, %9444 ], !dbg !341 + %9446 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1265 = icmp eq i32 %9446, 0, !dbg !341 + br i1 %.not.i1265, label %9449, label %9447, !dbg !341 + +9447: ; preds = %__nv_exp2f.exit1264 + %9448 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9362) #3, !dbg !341 + br label %__nv_exp2f.exit1267, !dbg !341 + +9449: ; preds = %__nv_exp2f.exit1264 + %9450 = tail call float @llvm.nvvm.ex2.approx.f(float %9362) #3, !dbg !341 + br label %__nv_exp2f.exit1267, !dbg !341 + +__nv_exp2f.exit1267: ; preds = %9447, %9449 + %.0.i1266 = phi float [ %9448, %9447 ], [ %9450, %9449 ], !dbg !341 + %9451 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1268 = icmp eq i32 %9451, 0, !dbg !341 + br i1 %.not.i1268, label %9454, label %9452, !dbg !341 + +9452: ; preds = %__nv_exp2f.exit1267 + %9453 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9363) #3, !dbg !341 + br label %__nv_exp2f.exit1270, !dbg !341 + +9454: ; preds = %__nv_exp2f.exit1267 + %9455 = tail call float @llvm.nvvm.ex2.approx.f(float %9363) #3, !dbg !341 + br label %__nv_exp2f.exit1270, !dbg !341 + +__nv_exp2f.exit1270: ; preds = %9452, %9454 + %.0.i1269 = phi float [ %9453, %9452 ], [ %9455, %9454 ], !dbg !341 + %9456 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1271 = icmp eq i32 %9456, 0, !dbg !341 + br i1 %.not.i1271, label %9459, label %9457, !dbg !341 + +9457: ; preds = %__nv_exp2f.exit1270 + %9458 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9364) #3, !dbg !341 + br label %__nv_exp2f.exit1273, !dbg !341 + +9459: ; preds = %__nv_exp2f.exit1270 + %9460 = tail call float @llvm.nvvm.ex2.approx.f(float %9364) #3, !dbg !341 + br label %__nv_exp2f.exit1273, !dbg !341 + +__nv_exp2f.exit1273: ; preds = %9457, %9459 + %.0.i1272 = phi float [ %9458, %9457 ], [ %9460, %9459 ], !dbg !341 + %9461 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1274 = icmp eq i32 %9461, 0, !dbg !341 + br i1 %.not.i1274, label %9464, label %9462, !dbg !341 + +9462: ; preds = %__nv_exp2f.exit1273 + %9463 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9365) #3, !dbg !341 + br label %__nv_exp2f.exit1276, !dbg !341 + +9464: ; preds = %__nv_exp2f.exit1273 + %9465 = tail call float @llvm.nvvm.ex2.approx.f(float %9365) #3, !dbg !341 + br label %__nv_exp2f.exit1276, !dbg !341 + +__nv_exp2f.exit1276: ; preds = %9462, %9464 + %.0.i1275 = phi float [ %9463, %9462 ], [ %9465, %9464 ], !dbg !341 + %9466 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1277 = icmp eq i32 %9466, 0, !dbg !341 + br i1 %.not.i1277, label %9469, label %9467, !dbg !341 + +9467: ; preds = %__nv_exp2f.exit1276 + %9468 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9366) #3, !dbg !341 + br label %__nv_exp2f.exit1279, !dbg !341 + +9469: ; preds = %__nv_exp2f.exit1276 + %9470 = tail call float @llvm.nvvm.ex2.approx.f(float %9366) #3, !dbg !341 + br label %__nv_exp2f.exit1279, !dbg !341 + +__nv_exp2f.exit1279: ; preds = %9467, %9469 + %.0.i1278 = phi float [ %9468, %9467 ], [ %9470, %9469 ], !dbg !341 + %9471 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1280 = icmp eq i32 %9471, 0, !dbg !341 + br i1 %.not.i1280, label %9474, label %9472, !dbg !341 + +9472: ; preds = %__nv_exp2f.exit1279 + %9473 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9367) #3, !dbg !341 + br label %__nv_exp2f.exit1282, !dbg !341 + +9474: ; preds = %__nv_exp2f.exit1279 + %9475 = tail call float @llvm.nvvm.ex2.approx.f(float %9367) #3, !dbg !341 + br label %__nv_exp2f.exit1282, !dbg !341 + +__nv_exp2f.exit1282: ; preds = %9472, %9474 + %.0.i1281 = phi float [ %9473, %9472 ], [ %9475, %9474 ], !dbg !341 + %9476 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1283 = icmp eq i32 %9476, 0, !dbg !341 + br i1 %.not.i1283, label %9479, label %9477, !dbg !341 + +9477: ; preds = %__nv_exp2f.exit1282 + %9478 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9368) #3, !dbg !341 + br label %__nv_exp2f.exit1285, !dbg !341 + +9479: ; preds = %__nv_exp2f.exit1282 + %9480 = tail call float @llvm.nvvm.ex2.approx.f(float %9368) #3, !dbg !341 + br label %__nv_exp2f.exit1285, !dbg !341 + +__nv_exp2f.exit1285: ; preds = %9477, %9479 + %.0.i1284 = phi float [ %9478, %9477 ], [ %9480, %9479 ], !dbg !341 + %9481 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1286 = icmp eq i32 %9481, 0, !dbg !341 + br i1 %.not.i1286, label %9484, label %9482, !dbg !341 + +9482: ; preds = %__nv_exp2f.exit1285 + %9483 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9369) #3, !dbg !341 + br label %__nv_exp2f.exit1288, !dbg !341 + +9484: ; preds = %__nv_exp2f.exit1285 + %9485 = tail call float @llvm.nvvm.ex2.approx.f(float %9369) #3, !dbg !341 + br label %__nv_exp2f.exit1288, !dbg !341 + +__nv_exp2f.exit1288: ; preds = %9482, %9484 + %.0.i1287 = phi float [ %9483, %9482 ], [ %9485, %9484 ], !dbg !341 + %9486 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1289 = icmp eq i32 %9486, 0, !dbg !341 + br i1 %.not.i1289, label %9489, label %9487, !dbg !341 + +9487: ; preds = %__nv_exp2f.exit1288 + %9488 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9370) #3, !dbg !341 + br label %__nv_exp2f.exit1291, !dbg !341 + +9489: ; preds = %__nv_exp2f.exit1288 + %9490 = tail call float @llvm.nvvm.ex2.approx.f(float %9370) #3, !dbg !341 + br label %__nv_exp2f.exit1291, !dbg !341 + +__nv_exp2f.exit1291: ; preds = %9487, %9489 + %.0.i1290 = phi float [ %9488, %9487 ], [ %9490, %9489 ], !dbg !341 + %9491 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1292 = icmp eq i32 %9491, 0, !dbg !341 + br i1 %.not.i1292, label %9494, label %9492, !dbg !341 + +9492: ; preds = %__nv_exp2f.exit1291 + %9493 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9371) #3, !dbg !341 + br label %__nv_exp2f.exit1294, !dbg !341 + +9494: ; preds = %__nv_exp2f.exit1291 + %9495 = tail call float @llvm.nvvm.ex2.approx.f(float %9371) #3, !dbg !341 + br label %__nv_exp2f.exit1294, !dbg !341 + +__nv_exp2f.exit1294: ; preds = %9492, %9494 + %.0.i1293 = phi float [ %9493, %9492 ], [ %9495, %9494 ], !dbg !341 + %9496 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1295 = icmp eq i32 %9496, 0, !dbg !341 + br i1 %.not.i1295, label %9499, label %9497, !dbg !341 + +9497: ; preds = %__nv_exp2f.exit1294 + %9498 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9372) #3, !dbg !341 + br label %__nv_exp2f.exit1297, !dbg !341 + +9499: ; preds = %__nv_exp2f.exit1294 + %9500 = tail call float @llvm.nvvm.ex2.approx.f(float %9372) #3, !dbg !341 + br label %__nv_exp2f.exit1297, !dbg !341 + +__nv_exp2f.exit1297: ; preds = %9497, %9499 + %.0.i1296 = phi float [ %9498, %9497 ], [ %9500, %9499 ], !dbg !341 + %9501 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1298 = icmp eq i32 %9501, 0, !dbg !341 + br i1 %.not.i1298, label %9504, label %9502, !dbg !341 + +9502: ; preds = %__nv_exp2f.exit1297 + %9503 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9373) #3, !dbg !341 + br label %__nv_exp2f.exit1300, !dbg !341 + +9504: ; preds = %__nv_exp2f.exit1297 + %9505 = tail call float @llvm.nvvm.ex2.approx.f(float %9373) #3, !dbg !341 + br label %__nv_exp2f.exit1300, !dbg !341 + +__nv_exp2f.exit1300: ; preds = %9502, %9504 + %.0.i1299 = phi float [ %9503, %9502 ], [ %9505, %9504 ], !dbg !341 + %9506 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1301 = icmp eq i32 %9506, 0, !dbg !341 + br i1 %.not.i1301, label %9509, label %9507, !dbg !341 + +9507: ; preds = %__nv_exp2f.exit1300 + %9508 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9374) #3, !dbg !341 + br label %__nv_exp2f.exit1303, !dbg !341 + +9509: ; preds = %__nv_exp2f.exit1300 + %9510 = tail call float @llvm.nvvm.ex2.approx.f(float %9374) #3, !dbg !341 + br label %__nv_exp2f.exit1303, !dbg !341 + +__nv_exp2f.exit1303: ; preds = %9507, %9509 + %.0.i1302 = phi float [ %9508, %9507 ], [ %9510, %9509 ], !dbg !341 + %9511 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1304 = icmp eq i32 %9511, 0, !dbg !341 + br i1 %.not.i1304, label %9514, label %9512, !dbg !341 + +9512: ; preds = %__nv_exp2f.exit1303 + %9513 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9375) #3, !dbg !341 + br label %__nv_exp2f.exit1306, !dbg !341 + +9514: ; preds = %__nv_exp2f.exit1303 + %9515 = tail call float @llvm.nvvm.ex2.approx.f(float %9375) #3, !dbg !341 + br label %__nv_exp2f.exit1306, !dbg !341 + +__nv_exp2f.exit1306: ; preds = %9512, %9514 + %.0.i1305 = phi float [ %9513, %9512 ], [ %9515, %9514 ], !dbg !341 + %9516 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1307 = icmp eq i32 %9516, 0, !dbg !341 + br i1 %.not.i1307, label %9519, label %9517, !dbg !341 + +9517: ; preds = %__nv_exp2f.exit1306 + %9518 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9376) #3, !dbg !341 + br label %__nv_exp2f.exit1309, !dbg !341 + +9519: ; preds = %__nv_exp2f.exit1306 + %9520 = tail call float @llvm.nvvm.ex2.approx.f(float %9376) #3, !dbg !341 + br label %__nv_exp2f.exit1309, !dbg !341 + +__nv_exp2f.exit1309: ; preds = %9517, %9519 + %.0.i1308 = phi float [ %9518, %9517 ], [ %9520, %9519 ], !dbg !341 + %9521 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1310 = icmp eq i32 %9521, 0, !dbg !341 + br i1 %.not.i1310, label %9524, label %9522, !dbg !341 + +9522: ; preds = %__nv_exp2f.exit1309 + %9523 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9377) #3, !dbg !341 + br label %__nv_exp2f.exit1312, !dbg !341 + +9524: ; preds = %__nv_exp2f.exit1309 + %9525 = tail call float @llvm.nvvm.ex2.approx.f(float %9377) #3, !dbg !341 + br label %__nv_exp2f.exit1312, !dbg !341 + +__nv_exp2f.exit1312: ; preds = %9522, %9524 + %.0.i1311 = phi float [ %9523, %9522 ], [ %9525, %9524 ], !dbg !341 + %9526 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1313 = icmp eq i32 %9526, 0, !dbg !341 + br i1 %.not.i1313, label %9529, label %9527, !dbg !341 + +9527: ; preds = %__nv_exp2f.exit1312 + %9528 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9378) #3, !dbg !341 + br label %__nv_exp2f.exit1315, !dbg !341 + +9529: ; preds = %__nv_exp2f.exit1312 + %9530 = tail call float @llvm.nvvm.ex2.approx.f(float %9378) #3, !dbg !341 + br label %__nv_exp2f.exit1315, !dbg !341 + +__nv_exp2f.exit1315: ; preds = %9527, %9529 + %.0.i1314 = phi float [ %9528, %9527 ], [ %9530, %9529 ], !dbg !341 + %9531 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1316 = icmp eq i32 %9531, 0, !dbg !341 + br i1 %.not.i1316, label %9534, label %9532, !dbg !341 + +9532: ; preds = %__nv_exp2f.exit1315 + %9533 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9379) #3, !dbg !341 + br label %__nv_exp2f.exit1318, !dbg !341 + +9534: ; preds = %__nv_exp2f.exit1315 + %9535 = tail call float @llvm.nvvm.ex2.approx.f(float %9379) #3, !dbg !341 + br label %__nv_exp2f.exit1318, !dbg !341 + +__nv_exp2f.exit1318: ; preds = %9532, %9534 + %.0.i1317 = phi float [ %9533, %9532 ], [ %9535, %9534 ], !dbg !341 + %9536 = tail call i32 @__nvvm_reflect(ptr nonnull @.str) #3, !dbg !341 + %.not.i1319 = icmp eq i32 %9536, 0, !dbg !341 + br i1 %.not.i1319, label %9539, label %9537, !dbg !341 + +9537: ; preds = %__nv_exp2f.exit1318 + %9538 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %9380) #3, !dbg !341 + br label %__nv_exp2f.exit1321, !dbg !341 + +9539: ; preds = %__nv_exp2f.exit1318 + %9540 = tail call float @llvm.nvvm.ex2.approx.f(float %9380) #3, !dbg !341 + br label %__nv_exp2f.exit1321, !dbg !341 + +__nv_exp2f.exit1321: ; preds = %9537, %9539 + %.0.i1320 = phi float [ %9538, %9537 ], [ %9540, %9539 ], !dbg !341 + %9541 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %8798, !dbg !327 + %9542 = insertelement <2 x float> poison, float %.0.i, i64 0, !dbg !342 + %9543 = insertelement <2 x float> %9542, float %.0.i1230, i64 1, !dbg !342 + %9544 = fptrunc <2 x float> %9543 to <2 x bfloat>, !dbg !342 + %9545 = insertelement <2 x float> poison, float %.0.i1233, i64 0, !dbg !342 + %9546 = insertelement <2 x float> %9545, float %.0.i1236, i64 1, !dbg !342 + %9547 = fptrunc <2 x float> %9546 to <2 x bfloat>, !dbg !342 + %9548 = insertelement <2 x float> poison, float %.0.i1239, i64 0, !dbg !342 + %9549 = insertelement <2 x float> %9548, float %.0.i1242, i64 1, !dbg !342 + %9550 = fptrunc <2 x float> %9549 to <2 x bfloat>, !dbg !342 + %9551 = insertelement <2 x float> poison, float %.0.i1245, i64 0, !dbg !342 + %9552 = insertelement <2 x float> %9551, float %.0.i1248, i64 1, !dbg !342 + %9553 = fptrunc <2 x float> %9552 to <2 x bfloat>, !dbg !342 + %9554 = insertelement <2 x float> poison, float %.0.i1251, i64 0, !dbg !342 + %9555 = insertelement <2 x float> %9554, float %.0.i1254, i64 1, !dbg !342 + %9556 = fptrunc <2 x float> %9555 to <2 x bfloat>, !dbg !342 + %9557 = insertelement <2 x float> poison, float %.0.i1257, i64 0, !dbg !342 + %9558 = insertelement <2 x float> %9557, float %.0.i1260, i64 1, !dbg !342 + %9559 = fptrunc <2 x float> %9558 to <2 x bfloat>, !dbg !342 + %9560 = insertelement <2 x float> poison, float %.0.i1263, i64 0, !dbg !342 + %9561 = insertelement <2 x float> %9560, float %.0.i1266, i64 1, !dbg !342 + %9562 = fptrunc <2 x float> %9561 to <2 x bfloat>, !dbg !342 + %9563 = insertelement <2 x float> poison, float %.0.i1269, i64 0, !dbg !342 + %9564 = insertelement <2 x float> %9563, float %.0.i1272, i64 1, !dbg !342 + %9565 = fptrunc <2 x float> %9564 to <2 x bfloat>, !dbg !342 + %9566 = insertelement <2 x float> poison, float %.0.i1275, i64 0, !dbg !342 + %9567 = insertelement <2 x float> %9566, float %.0.i1278, i64 1, !dbg !342 + %9568 = fptrunc <2 x float> %9567 to <2 x bfloat>, !dbg !342 + %9569 = insertelement <2 x float> poison, float %.0.i1281, i64 0, !dbg !342 + %9570 = insertelement <2 x float> %9569, float %.0.i1284, i64 1, !dbg !342 + %9571 = fptrunc <2 x float> %9570 to <2 x bfloat>, !dbg !342 + %9572 = insertelement <2 x float> poison, float %.0.i1287, i64 0, !dbg !342 + %9573 = insertelement <2 x float> %9572, float %.0.i1290, i64 1, !dbg !342 + %9574 = fptrunc <2 x float> %9573 to <2 x bfloat>, !dbg !342 + %9575 = insertelement <2 x float> poison, float %.0.i1293, i64 0, !dbg !342 + %9576 = insertelement <2 x float> %9575, float %.0.i1296, i64 1, !dbg !342 + %9577 = fptrunc <2 x float> %9576 to <2 x bfloat>, !dbg !342 + %9578 = insertelement <2 x float> poison, float %.0.i1299, i64 0, !dbg !342 + %9579 = insertelement <2 x float> %9578, float %.0.i1302, i64 1, !dbg !342 + %9580 = fptrunc <2 x float> %9579 to <2 x bfloat>, !dbg !342 + %9581 = insertelement <2 x float> poison, float %.0.i1305, i64 0, !dbg !342 + %9582 = insertelement <2 x float> %9581, float %.0.i1308, i64 1, !dbg !342 + %9583 = fptrunc <2 x float> %9582 to <2 x bfloat>, !dbg !342 + %9584 = insertelement <2 x float> poison, float %.0.i1311, i64 0, !dbg !342 + %9585 = insertelement <2 x float> %9584, float %.0.i1314, i64 1, !dbg !342 + %9586 = fptrunc <2 x float> %9585 to <2 x bfloat>, !dbg !342 + %9587 = insertelement <2 x float> poison, float %.0.i1317, i64 0, !dbg !342 + %9588 = insertelement <2 x float> %9587, float %.0.i1320, i64 1, !dbg !342 + %9589 = fptrunc <2 x float> %9588 to <2 x bfloat>, !dbg !342 + %9590 = bitcast <2 x bfloat> %9544 to i32, !dbg !343 + %9591 = bitcast <2 x bfloat> %9547 to i32, !dbg !343 + %9592 = bitcast <2 x bfloat> %9550 to i32, !dbg !343 + %9593 = bitcast <2 x bfloat> %9553 to i32, !dbg !343 + %9594 = bitcast <2 x bfloat> %9556 to i32, !dbg !343 + %9595 = bitcast <2 x bfloat> %9559 to i32, !dbg !343 + %9596 = bitcast <2 x bfloat> %9562 to i32, !dbg !343 + %9597 = bitcast <2 x bfloat> %9565 to i32, !dbg !343 + %9598 = bitcast <2 x bfloat> %9568 to i32, !dbg !343 + %9599 = bitcast <2 x bfloat> %9571 to i32, !dbg !343 + %9600 = bitcast <2 x bfloat> %9574 to i32, !dbg !343 + %9601 = bitcast <2 x bfloat> %9577 to i32, !dbg !343 + %9602 = bitcast <2 x bfloat> %9580 to i32, !dbg !343 + %9603 = bitcast <2 x bfloat> %9583 to i32, !dbg !343 + %9604 = bitcast <2 x bfloat> %9586 to i32, !dbg !343 + %9605 = bitcast <2 x bfloat> %9589 to i32, !dbg !343 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !343 + %9606 = ptrtoint ptr addrspace(3) %9541 to i32, !dbg !343 + %9607 = lshr exact i32 %9606, 4, !dbg !343 + %9608 = and i32 %9607, 16383, !dbg !343 + %9609 = zext nneg i32 %9608 to i64, !dbg !343 + %9610 = or disjoint i64 %9609, 4611686293338849280, !dbg !343 + %9611 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn4751768, float %.pn4731769, float %.pn4711770, float %.pn4691771, float %.pn4671772, float %.pn4651773, float %.pn4631774, float %.pn4611775, float %.pn4591776, float %.pn4571777, float %.pn4551778, float %.pn4531779, float %.pn4511780, float %.pn4491781, float %.pn4471782, float %.pn4451783, float %.pn4431784, float %.pn4411785, float %.pn4391786, float %.pn4371787, float %.pn4351788, float %.pn4331789, float %.pn4311790, float %.pn4291791, float %.pn4271792, float %.pn4251793, float %.pn4231794, float %.pn4211795, float %.pn4191796, float %.pn4171797, float %.pn4151798, float %.pn4131799, float %.pn4111800, float %.pn4091801, float %.pn4071802, float %.pn4051803, float %.pn4031804, float %.pn4011805, float %.pn3991806, float %.pn3971807, float %.pn3951808, float %.pn3931809, float %.pn3911810, float %.pn3891811, float %.pn3871812, float %.pn3851813, float %.pn3831814, float %.pn3811815, float %.pn3791816, float %.pn3771817, float %.pn3751818, float %.pn3731819, float %.pn3711820, float %.pn3691821, float %.pn3671822, float %.pn3651823, float %.pn3631824, float %.pn3611825, float %.pn3591826, float %.pn3571827, float %.pn3551828, float %.pn3531829, float %.pn3511830, float %.pn3491831, i32 %9590, i32 %9591, i32 %9592, i32 %9593, i64 %9610, i1 true) #3, !dbg !343 + %9612 = add i32 %9606, 2048, !dbg !343 + %9613 = lshr exact i32 %9612, 4, !dbg !343 + %9614 = and i32 %9613, 16383, !dbg !343 + %9615 = zext nneg i32 %9614 to i64, !dbg !343 + %9616 = or disjoint i64 %9615, 4611686293338849280, !dbg !343 + %9617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 0, !dbg !343 + %9618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 1, !dbg !343 + %9619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 2, !dbg !343 + %9620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 3, !dbg !343 + %9621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 4, !dbg !343 + %9622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 5, !dbg !343 + %9623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 6, !dbg !343 + %9624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 7, !dbg !343 + %9625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 8, !dbg !343 + %9626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 9, !dbg !343 + %9627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 10, !dbg !343 + %9628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 11, !dbg !343 + %9629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 12, !dbg !343 + %9630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 13, !dbg !343 + %9631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 14, !dbg !343 + %9632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 15, !dbg !343 + %9633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 16, !dbg !343 + %9634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 17, !dbg !343 + %9635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 18, !dbg !343 + %9636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 19, !dbg !343 + %9637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 20, !dbg !343 + %9638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 21, !dbg !343 + %9639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 22, !dbg !343 + %9640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 23, !dbg !343 + %9641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 24, !dbg !343 + %9642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 25, !dbg !343 + %9643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 26, !dbg !343 + %9644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 27, !dbg !343 + %9645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 28, !dbg !343 + %9646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 29, !dbg !343 + %9647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 30, !dbg !343 + %9648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 31, !dbg !343 + %9649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 32, !dbg !343 + %9650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 33, !dbg !343 + %9651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 34, !dbg !343 + %9652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 35, !dbg !343 + %9653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 36, !dbg !343 + %9654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 37, !dbg !343 + %9655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 38, !dbg !343 + %9656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 39, !dbg !343 + %9657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 40, !dbg !343 + %9658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 41, !dbg !343 + %9659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 42, !dbg !343 + %9660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 43, !dbg !343 + %9661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 44, !dbg !343 + %9662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 45, !dbg !343 + %9663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 46, !dbg !343 + %9664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 47, !dbg !343 + %9665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 48, !dbg !343 + %9666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 49, !dbg !343 + %9667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 50, !dbg !343 + %9668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 51, !dbg !343 + %9669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 52, !dbg !343 + %9670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 53, !dbg !343 + %9671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 54, !dbg !343 + %9672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 55, !dbg !343 + %9673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 56, !dbg !343 + %9674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 57, !dbg !343 + %9675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 58, !dbg !343 + %9676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 59, !dbg !343 + %9677 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 60, !dbg !343 + %9678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 61, !dbg !343 + %9679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 62, !dbg !343 + %9680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9611, 63, !dbg !343 + %9681 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9617, float %9618, float %9619, float %9620, float %9621, float %9622, float %9623, float %9624, float %9625, float %9626, float %9627, float %9628, float %9629, float %9630, float %9631, float %9632, float %9633, float %9634, float %9635, float %9636, float %9637, float %9638, float %9639, float %9640, float %9641, float %9642, float %9643, float %9644, float %9645, float %9646, float %9647, float %9648, float %9649, float %9650, float %9651, float %9652, float %9653, float %9654, float %9655, float %9656, float %9657, float %9658, float %9659, float %9660, float %9661, float %9662, float %9663, float %9664, float %9665, float %9666, float %9667, float %9668, float %9669, float %9670, float %9671, float %9672, float %9673, float %9674, float %9675, float %9676, float %9677, float %9678, float %9679, float %9680, i32 %9594, i32 %9595, i32 %9596, i32 %9597, i64 %9616, i1 true) #3, !dbg !343 + %9682 = add i32 %9606, 4096, !dbg !343 + %9683 = lshr exact i32 %9682, 4, !dbg !343 + %9684 = and i32 %9683, 16383, !dbg !343 + %9685 = zext nneg i32 %9684 to i64, !dbg !343 + %9686 = or disjoint i64 %9685, 4611686293338849280, !dbg !343 + %9687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 0, !dbg !343 + %9688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 1, !dbg !343 + %9689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 2, !dbg !343 + %9690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 3, !dbg !343 + %9691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 4, !dbg !343 + %9692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 5, !dbg !343 + %9693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 6, !dbg !343 + %9694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 7, !dbg !343 + %9695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 8, !dbg !343 + %9696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 9, !dbg !343 + %9697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 10, !dbg !343 + %9698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 11, !dbg !343 + %9699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 12, !dbg !343 + %9700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 13, !dbg !343 + %9701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 14, !dbg !343 + %9702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 15, !dbg !343 + %9703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 16, !dbg !343 + %9704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 17, !dbg !343 + %9705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 18, !dbg !343 + %9706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 19, !dbg !343 + %9707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 20, !dbg !343 + %9708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 21, !dbg !343 + %9709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 22, !dbg !343 + %9710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 23, !dbg !343 + %9711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 24, !dbg !343 + %9712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 25, !dbg !343 + %9713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 26, !dbg !343 + %9714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 27, !dbg !343 + %9715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 28, !dbg !343 + %9716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 29, !dbg !343 + %9717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 30, !dbg !343 + %9718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 31, !dbg !343 + %9719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 32, !dbg !343 + %9720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 33, !dbg !343 + %9721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 34, !dbg !343 + %9722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 35, !dbg !343 + %9723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 36, !dbg !343 + %9724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 37, !dbg !343 + %9725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 38, !dbg !343 + %9726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 39, !dbg !343 + %9727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 40, !dbg !343 + %9728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 41, !dbg !343 + %9729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 42, !dbg !343 + %9730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 43, !dbg !343 + %9731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 44, !dbg !343 + %9732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 45, !dbg !343 + %9733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 46, !dbg !343 + %9734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 47, !dbg !343 + %9735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 48, !dbg !343 + %9736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 49, !dbg !343 + %9737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 50, !dbg !343 + %9738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 51, !dbg !343 + %9739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 52, !dbg !343 + %9740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 53, !dbg !343 + %9741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 54, !dbg !343 + %9742 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 55, !dbg !343 + %9743 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 56, !dbg !343 + %9744 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 57, !dbg !343 + %9745 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 58, !dbg !343 + %9746 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 59, !dbg !343 + %9747 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 60, !dbg !343 + %9748 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 61, !dbg !343 + %9749 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 62, !dbg !343 + %9750 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9681, 63, !dbg !343 + %9751 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9687, float %9688, float %9689, float %9690, float %9691, float %9692, float %9693, float %9694, float %9695, float %9696, float %9697, float %9698, float %9699, float %9700, float %9701, float %9702, float %9703, float %9704, float %9705, float %9706, float %9707, float %9708, float %9709, float %9710, float %9711, float %9712, float %9713, float %9714, float %9715, float %9716, float %9717, float %9718, float %9719, float %9720, float %9721, float %9722, float %9723, float %9724, float %9725, float %9726, float %9727, float %9728, float %9729, float %9730, float %9731, float %9732, float %9733, float %9734, float %9735, float %9736, float %9737, float %9738, float %9739, float %9740, float %9741, float %9742, float %9743, float %9744, float %9745, float %9746, float %9747, float %9748, float %9749, float %9750, i32 %9598, i32 %9599, i32 %9600, i32 %9601, i64 %9686, i1 true) #3, !dbg !343 + %9752 = add i32 %9606, 6144, !dbg !343 + %9753 = lshr exact i32 %9752, 4, !dbg !343 + %9754 = and i32 %9753, 16383, !dbg !343 + %9755 = zext nneg i32 %9754 to i64, !dbg !343 + %9756 = or disjoint i64 %9755, 4611686293338849280, !dbg !343 + %9757 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 0, !dbg !343 + %9758 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 1, !dbg !343 + %9759 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 2, !dbg !343 + %9760 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 3, !dbg !343 + %9761 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 4, !dbg !343 + %9762 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 5, !dbg !343 + %9763 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 6, !dbg !343 + %9764 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 7, !dbg !343 + %9765 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 8, !dbg !343 + %9766 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 9, !dbg !343 + %9767 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 10, !dbg !343 + %9768 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 11, !dbg !343 + %9769 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 12, !dbg !343 + %9770 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 13, !dbg !343 + %9771 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 14, !dbg !343 + %9772 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 15, !dbg !343 + %9773 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 16, !dbg !343 + %9774 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 17, !dbg !343 + %9775 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 18, !dbg !343 + %9776 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 19, !dbg !343 + %9777 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 20, !dbg !343 + %9778 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 21, !dbg !343 + %9779 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 22, !dbg !343 + %9780 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 23, !dbg !343 + %9781 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 24, !dbg !343 + %9782 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 25, !dbg !343 + %9783 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 26, !dbg !343 + %9784 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 27, !dbg !343 + %9785 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 28, !dbg !343 + %9786 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 29, !dbg !343 + %9787 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 30, !dbg !343 + %9788 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 31, !dbg !343 + %9789 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 32, !dbg !343 + %9790 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 33, !dbg !343 + %9791 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 34, !dbg !343 + %9792 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 35, !dbg !343 + %9793 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 36, !dbg !343 + %9794 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 37, !dbg !343 + %9795 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 38, !dbg !343 + %9796 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 39, !dbg !343 + %9797 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 40, !dbg !343 + %9798 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 41, !dbg !343 + %9799 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 42, !dbg !343 + %9800 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 43, !dbg !343 + %9801 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 44, !dbg !343 + %9802 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 45, !dbg !343 + %9803 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 46, !dbg !343 + %9804 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 47, !dbg !343 + %9805 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 48, !dbg !343 + %9806 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 49, !dbg !343 + %9807 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 50, !dbg !343 + %9808 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 51, !dbg !343 + %9809 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 52, !dbg !343 + %9810 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 53, !dbg !343 + %9811 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 54, !dbg !343 + %9812 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 55, !dbg !343 + %9813 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 56, !dbg !343 + %9814 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 57, !dbg !343 + %9815 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 58, !dbg !343 + %9816 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 59, !dbg !343 + %9817 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 60, !dbg !343 + %9818 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 61, !dbg !343 + %9819 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 62, !dbg !343 + %9820 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9751, 63, !dbg !343 + %9821 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %9757, float %9758, float %9759, float %9760, float %9761, float %9762, float %9763, float %9764, float %9765, float %9766, float %9767, float %9768, float %9769, float %9770, float %9771, float %9772, float %9773, float %9774, float %9775, float %9776, float %9777, float %9778, float %9779, float %9780, float %9781, float %9782, float %9783, float %9784, float %9785, float %9786, float %9787, float %9788, float %9789, float %9790, float %9791, float %9792, float %9793, float %9794, float %9795, float %9796, float %9797, float %9798, float %9799, float %9800, float %9801, float %9802, float %9803, float %9804, float %9805, float %9806, float %9807, float %9808, float %9809, float %9810, float %9811, float %9812, float %9813, float %9814, float %9815, float %9816, float %9817, float %9818, float %9819, float %9820, i32 %9602, i32 %9603, i32 %9604, i32 %9605, i64 %9756, i1 true) #3, !dbg !343 + %9822 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 0, !dbg !343 + %9823 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 1, !dbg !343 + %9824 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 2, !dbg !343 + %9825 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 3, !dbg !343 + %9826 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 4, !dbg !343 + %9827 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 5, !dbg !343 + %9828 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 6, !dbg !343 + %9829 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 7, !dbg !343 + %9830 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 8, !dbg !343 + %9831 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 9, !dbg !343 + %9832 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 10, !dbg !343 + %9833 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 11, !dbg !343 + %9834 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 12, !dbg !343 + %9835 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 13, !dbg !343 + %9836 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 14, !dbg !343 + %9837 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 15, !dbg !343 + %9838 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 16, !dbg !343 + %9839 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 17, !dbg !343 + %9840 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 18, !dbg !343 + %9841 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 19, !dbg !343 + %9842 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 20, !dbg !343 + %9843 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 21, !dbg !343 + %9844 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 22, !dbg !343 + %9845 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 23, !dbg !343 + %9846 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 24, !dbg !343 + %9847 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 25, !dbg !343 + %9848 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 26, !dbg !343 + %9849 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 27, !dbg !343 + %9850 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 28, !dbg !343 + %9851 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 29, !dbg !343 + %9852 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 30, !dbg !343 + %9853 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 31, !dbg !343 + %9854 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 32, !dbg !343 + %9855 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 33, !dbg !343 + %9856 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 34, !dbg !343 + %9857 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 35, !dbg !343 + %9858 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 36, !dbg !343 + %9859 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 37, !dbg !343 + %9860 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 38, !dbg !343 + %9861 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 39, !dbg !343 + %9862 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 40, !dbg !343 + %9863 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 41, !dbg !343 + %9864 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 42, !dbg !343 + %9865 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 43, !dbg !343 + %9866 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 44, !dbg !343 + %9867 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 45, !dbg !343 + %9868 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 46, !dbg !343 + %9869 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 47, !dbg !343 + %9870 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 48, !dbg !343 + %9871 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 49, !dbg !343 + %9872 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 50, !dbg !343 + %9873 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 51, !dbg !343 + %9874 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 52, !dbg !343 + %9875 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 53, !dbg !343 + %9876 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 54, !dbg !343 + %9877 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 55, !dbg !343 + %9878 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 56, !dbg !343 + %9879 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 57, !dbg !343 + %9880 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 58, !dbg !343 + %9881 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 59, !dbg !343 + %9882 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 60, !dbg !343 + %9883 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 61, !dbg !343 + %9884 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 62, !dbg !343 + %9885 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9821, 63, !dbg !343 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !343 + %9886 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %8800, !dbg !329 + %9887 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5179, !dbg !329 + %9888 = load float, ptr addrspace(3) %9887, align 8, !dbg !329 + %9889 = getelementptr inbounds nuw i8, ptr addrspace(3) %9887, i32 4, !dbg !329 + %9890 = load float, ptr addrspace(3) %9889, align 4, !dbg !329 + %9891 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5185, !dbg !329 + %9892 = load float, ptr addrspace(3) %9891, align 8, !dbg !329 + %9893 = getelementptr inbounds nuw i8, ptr addrspace(3) %9891, i32 4, !dbg !329 + %9894 = load float, ptr addrspace(3) %9893, align 4, !dbg !329 + %9895 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5191, !dbg !329 + %9896 = load float, ptr addrspace(3) %9895, align 8, !dbg !329 + %9897 = getelementptr inbounds nuw i8, ptr addrspace(3) %9895, i32 4, !dbg !329 + %9898 = load float, ptr addrspace(3) %9897, align 4, !dbg !329 + %9899 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5197, !dbg !329 + %9900 = load float, ptr addrspace(3) %9899, align 8, !dbg !329 + %9901 = getelementptr inbounds nuw i8, ptr addrspace(3) %9899, i32 4, !dbg !329 + %9902 = load float, ptr addrspace(3) %9901, align 4, !dbg !329 + %9903 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5203, !dbg !329 + %9904 = load float, ptr addrspace(3) %9903, align 8, !dbg !329 + %9905 = getelementptr inbounds nuw i8, ptr addrspace(3) %9903, i32 4, !dbg !329 + %9906 = load float, ptr addrspace(3) %9905, align 4, !dbg !329 + %9907 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5209, !dbg !329 + %9908 = load float, ptr addrspace(3) %9907, align 8, !dbg !329 + %9909 = getelementptr inbounds nuw i8, ptr addrspace(3) %9907, i32 4, !dbg !329 + %9910 = load float, ptr addrspace(3) %9909, align 4, !dbg !329 + %9911 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5215, !dbg !329 + %9912 = load float, ptr addrspace(3) %9911, align 8, !dbg !329 + %9913 = getelementptr inbounds nuw i8, ptr addrspace(3) %9911, i32 4, !dbg !329 + %9914 = load float, ptr addrspace(3) %9913, align 4, !dbg !329 + %9915 = getelementptr inbounds nuw i8, ptr addrspace(3) %9886, i32 %5221, !dbg !329 + %9916 = load float, ptr addrspace(3) %9915, align 8, !dbg !329 + %9917 = getelementptr inbounds nuw i8, ptr addrspace(3) %9915, i32 4, !dbg !329 + %9918 = load float, ptr addrspace(3) %9917, align 4, !dbg !329 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !344 + %9919 = add i32 %8868, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %9920 = lshr exact i32 %9919, 4, !dbg !344 + %9921 = and i32 %9920, 16383, !dbg !344 + %9922 = zext nneg i32 %9921 to i64, !dbg !344 + %9923 = or disjoint i64 %9922, 4611686293372403712, !dbg !344 + %9924 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $32, $33, 0, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l"(i64 %9923, i64 %9610) #3, !dbg !344 + %9925 = add i32 %8880, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %9926 = lshr exact i32 %9925, 4, !dbg !344 + %9927 = and i32 %9926, 16383, !dbg !344 + %9928 = zext nneg i32 %9927 to i64, !dbg !344 + %9929 = or disjoint i64 %9928, 4611686293372403712, !dbg !344 + %9930 = add i32 %9606, 32, !dbg !344 + %9931 = lshr exact i32 %9930, 4, !dbg !344 + %9932 = and i32 %9931, 16383, !dbg !344 + %9933 = zext nneg i32 %9932 to i64, !dbg !344 + %9934 = or disjoint i64 %9933, 4611686293338849280, !dbg !344 + %9935 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 0, !dbg !344 + %9936 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 1, !dbg !344 + %9937 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 2, !dbg !344 + %9938 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 3, !dbg !344 + %9939 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 4, !dbg !344 + %9940 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 5, !dbg !344 + %9941 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 6, !dbg !344 + %9942 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 7, !dbg !344 + %9943 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 8, !dbg !344 + %9944 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 9, !dbg !344 + %9945 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 10, !dbg !344 + %9946 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 11, !dbg !344 + %9947 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 12, !dbg !344 + %9948 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 13, !dbg !344 + %9949 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 14, !dbg !344 + %9950 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 15, !dbg !344 + %9951 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 16, !dbg !344 + %9952 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 17, !dbg !344 + %9953 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 18, !dbg !344 + %9954 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 19, !dbg !344 + %9955 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 20, !dbg !344 + %9956 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 21, !dbg !344 + %9957 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 22, !dbg !344 + %9958 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 23, !dbg !344 + %9959 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 24, !dbg !344 + %9960 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 25, !dbg !344 + %9961 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 26, !dbg !344 + %9962 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 27, !dbg !344 + %9963 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 28, !dbg !344 + %9964 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 29, !dbg !344 + %9965 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 30, !dbg !344 + %9966 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9924, 31, !dbg !344 + %9967 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9935, float %9936, float %9937, float %9938, float %9939, float %9940, float %9941, float %9942, float %9943, float %9944, float %9945, float %9946, float %9947, float %9948, float %9949, float %9950, float %9951, float %9952, float %9953, float %9954, float %9955, float %9956, float %9957, float %9958, float %9959, float %9960, float %9961, float %9962, float %9963, float %9964, float %9965, float %9966, i64 %9929, i64 %9934, i1 true) #3, !dbg !344 + %9968 = add i32 %8924, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %9969 = lshr exact i32 %9968, 4, !dbg !344 + %9970 = and i32 %9969, 16383, !dbg !344 + %9971 = zext nneg i32 %9970 to i64, !dbg !344 + %9972 = or disjoint i64 %9971, 4611686293372403712, !dbg !344 + %9973 = add i32 %9606, 64, !dbg !344 + %9974 = lshr exact i32 %9973, 4, !dbg !344 + %9975 = and i32 %9974, 16383, !dbg !344 + %9976 = zext nneg i32 %9975 to i64, !dbg !344 + %9977 = or disjoint i64 %9976, 4611686293338849280, !dbg !344 + %9978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 0, !dbg !344 + %9979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 1, !dbg !344 + %9980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 2, !dbg !344 + %9981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 3, !dbg !344 + %9982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 4, !dbg !344 + %9983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 5, !dbg !344 + %9984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 6, !dbg !344 + %9985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 7, !dbg !344 + %9986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 8, !dbg !344 + %9987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 9, !dbg !344 + %9988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 10, !dbg !344 + %9989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 11, !dbg !344 + %9990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 12, !dbg !344 + %9991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 13, !dbg !344 + %9992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 14, !dbg !344 + %9993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 15, !dbg !344 + %9994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 16, !dbg !344 + %9995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 17, !dbg !344 + %9996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 18, !dbg !344 + %9997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 19, !dbg !344 + %9998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 20, !dbg !344 + %9999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 21, !dbg !344 + %10000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 22, !dbg !344 + %10001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 23, !dbg !344 + %10002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 24, !dbg !344 + %10003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 25, !dbg !344 + %10004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 26, !dbg !344 + %10005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 27, !dbg !344 + %10006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 28, !dbg !344 + %10007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 29, !dbg !344 + %10008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 30, !dbg !344 + %10009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %9967, 31, !dbg !344 + %10010 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %9978, float %9979, float %9980, float %9981, float %9982, float %9983, float %9984, float %9985, float %9986, float %9987, float %9988, float %9989, float %9990, float %9991, float %9992, float %9993, float %9994, float %9995, float %9996, float %9997, float %9998, float %9999, float %10000, float %10001, float %10002, float %10003, float %10004, float %10005, float %10006, float %10007, float %10008, float %10009, i64 %9972, i64 %9977, i1 true) #3, !dbg !344 + %10011 = add i32 %8968, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %10012 = lshr exact i32 %10011, 4, !dbg !344 + %10013 = and i32 %10012, 16383, !dbg !344 + %10014 = zext nneg i32 %10013 to i64, !dbg !344 + %10015 = or disjoint i64 %10014, 4611686293372403712, !dbg !344 + %10016 = add i32 %9606, 96, !dbg !344 + %10017 = lshr exact i32 %10016, 4, !dbg !344 + %10018 = and i32 %10017, 16383, !dbg !344 + %10019 = zext nneg i32 %10018 to i64, !dbg !344 + %10020 = or disjoint i64 %10019, 4611686293338849280, !dbg !344 + %10021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 0, !dbg !344 + %10022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 1, !dbg !344 + %10023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 2, !dbg !344 + %10024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 3, !dbg !344 + %10025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 4, !dbg !344 + %10026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 5, !dbg !344 + %10027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 6, !dbg !344 + %10028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 7, !dbg !344 + %10029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 8, !dbg !344 + %10030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 9, !dbg !344 + %10031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 10, !dbg !344 + %10032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 11, !dbg !344 + %10033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 12, !dbg !344 + %10034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 13, !dbg !344 + %10035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 14, !dbg !344 + %10036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 15, !dbg !344 + %10037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 16, !dbg !344 + %10038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 17, !dbg !344 + %10039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 18, !dbg !344 + %10040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 19, !dbg !344 + %10041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 20, !dbg !344 + %10042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 21, !dbg !344 + %10043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 22, !dbg !344 + %10044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 23, !dbg !344 + %10045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 24, !dbg !344 + %10046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 25, !dbg !344 + %10047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 26, !dbg !344 + %10048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 27, !dbg !344 + %10049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 28, !dbg !344 + %10050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 29, !dbg !344 + %10051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 30, !dbg !344 + %10052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10010, 31, !dbg !344 + %10053 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10021, float %10022, float %10023, float %10024, float %10025, float %10026, float %10027, float %10028, float %10029, float %10030, float %10031, float %10032, float %10033, float %10034, float %10035, float %10036, float %10037, float %10038, float %10039, float %10040, float %10041, float %10042, float %10043, float %10044, float %10045, float %10046, float %10047, float %10048, float %10049, float %10050, float %10051, float %10052, i64 %10015, i64 %10020, i1 true) #3, !dbg !344 + %10054 = add i32 %9012, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %10055 = lshr exact i32 %10054, 4, !dbg !344 + %10056 = and i32 %10055, 16383, !dbg !344 + %10057 = zext nneg i32 %10056 to i64, !dbg !344 + %10058 = or disjoint i64 %10057, 4611686293372403712, !dbg !344 + %10059 = add i32 %9606, 8192, !dbg !344 + %10060 = lshr exact i32 %10059, 4, !dbg !344 + %10061 = and i32 %10060, 16383, !dbg !344 + %10062 = zext nneg i32 %10061 to i64, !dbg !344 + %10063 = or disjoint i64 %10062, 4611686293338849280, !dbg !344 + %10064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 0, !dbg !344 + %10065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 1, !dbg !344 + %10066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 2, !dbg !344 + %10067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 3, !dbg !344 + %10068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 4, !dbg !344 + %10069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 5, !dbg !344 + %10070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 6, !dbg !344 + %10071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 7, !dbg !344 + %10072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 8, !dbg !344 + %10073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 9, !dbg !344 + %10074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 10, !dbg !344 + %10075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 11, !dbg !344 + %10076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 12, !dbg !344 + %10077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 13, !dbg !344 + %10078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 14, !dbg !344 + %10079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 15, !dbg !344 + %10080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 16, !dbg !344 + %10081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 17, !dbg !344 + %10082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 18, !dbg !344 + %10083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 19, !dbg !344 + %10084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 20, !dbg !344 + %10085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 21, !dbg !344 + %10086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 22, !dbg !344 + %10087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 23, !dbg !344 + %10088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 24, !dbg !344 + %10089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 25, !dbg !344 + %10090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 26, !dbg !344 + %10091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 27, !dbg !344 + %10092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 28, !dbg !344 + %10093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 29, !dbg !344 + %10094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 30, !dbg !344 + %10095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10053, 31, !dbg !344 + %10096 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10064, float %10065, float %10066, float %10067, float %10068, float %10069, float %10070, float %10071, float %10072, float %10073, float %10074, float %10075, float %10076, float %10077, float %10078, float %10079, float %10080, float %10081, float %10082, float %10083, float %10084, float %10085, float %10086, float %10087, float %10088, float %10089, float %10090, float %10091, float %10092, float %10093, float %10094, float %10095, i64 %10058, i64 %10063, i1 true) #3, !dbg !344 + %10097 = add i32 %9056, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %10098 = lshr exact i32 %10097, 4, !dbg !344 + %10099 = and i32 %10098, 16383, !dbg !344 + %10100 = zext nneg i32 %10099 to i64, !dbg !344 + %10101 = or disjoint i64 %10100, 4611686293372403712, !dbg !344 + %10102 = add i32 %9606, 8224, !dbg !344 + %10103 = lshr exact i32 %10102, 4, !dbg !344 + %10104 = and i32 %10103, 16383, !dbg !344 + %10105 = zext nneg i32 %10104 to i64, !dbg !344 + %10106 = or disjoint i64 %10105, 4611686293338849280, !dbg !344 + %10107 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 0, !dbg !344 + %10108 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 1, !dbg !344 + %10109 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 2, !dbg !344 + %10110 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 3, !dbg !344 + %10111 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 4, !dbg !344 + %10112 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 5, !dbg !344 + %10113 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 6, !dbg !344 + %10114 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 7, !dbg !344 + %10115 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 8, !dbg !344 + %10116 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 9, !dbg !344 + %10117 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 10, !dbg !344 + %10118 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 11, !dbg !344 + %10119 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 12, !dbg !344 + %10120 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 13, !dbg !344 + %10121 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 14, !dbg !344 + %10122 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 15, !dbg !344 + %10123 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 16, !dbg !344 + %10124 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 17, !dbg !344 + %10125 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 18, !dbg !344 + %10126 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 19, !dbg !344 + %10127 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 20, !dbg !344 + %10128 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 21, !dbg !344 + %10129 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 22, !dbg !344 + %10130 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 23, !dbg !344 + %10131 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 24, !dbg !344 + %10132 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 25, !dbg !344 + %10133 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 26, !dbg !344 + %10134 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 27, !dbg !344 + %10135 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 28, !dbg !344 + %10136 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 29, !dbg !344 + %10137 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 30, !dbg !344 + %10138 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10096, 31, !dbg !344 + %10139 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10107, float %10108, float %10109, float %10110, float %10111, float %10112, float %10113, float %10114, float %10115, float %10116, float %10117, float %10118, float %10119, float %10120, float %10121, float %10122, float %10123, float %10124, float %10125, float %10126, float %10127, float %10128, float %10129, float %10130, float %10131, float %10132, float %10133, float %10134, float %10135, float %10136, float %10137, float %10138, i64 %10101, i64 %10106, i1 true) #3, !dbg !344 + %10140 = add i32 %9100, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %10141 = lshr exact i32 %10140, 4, !dbg !344 + %10142 = and i32 %10141, 16383, !dbg !344 + %10143 = zext nneg i32 %10142 to i64, !dbg !344 + %10144 = or disjoint i64 %10143, 4611686293372403712, !dbg !344 + %10145 = add i32 %9606, 8256, !dbg !344 + %10146 = lshr exact i32 %10145, 4, !dbg !344 + %10147 = and i32 %10146, 16383, !dbg !344 + %10148 = zext nneg i32 %10147 to i64, !dbg !344 + %10149 = or disjoint i64 %10148, 4611686293338849280, !dbg !344 + %10150 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 0, !dbg !344 + %10151 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 1, !dbg !344 + %10152 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 2, !dbg !344 + %10153 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 3, !dbg !344 + %10154 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 4, !dbg !344 + %10155 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 5, !dbg !344 + %10156 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 6, !dbg !344 + %10157 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 7, !dbg !344 + %10158 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 8, !dbg !344 + %10159 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 9, !dbg !344 + %10160 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 10, !dbg !344 + %10161 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 11, !dbg !344 + %10162 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 12, !dbg !344 + %10163 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 13, !dbg !344 + %10164 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 14, !dbg !344 + %10165 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 15, !dbg !344 + %10166 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 16, !dbg !344 + %10167 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 17, !dbg !344 + %10168 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 18, !dbg !344 + %10169 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 19, !dbg !344 + %10170 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 20, !dbg !344 + %10171 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 21, !dbg !344 + %10172 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 22, !dbg !344 + %10173 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 23, !dbg !344 + %10174 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 24, !dbg !344 + %10175 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 25, !dbg !344 + %10176 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 26, !dbg !344 + %10177 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 27, !dbg !344 + %10178 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 28, !dbg !344 + %10179 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 29, !dbg !344 + %10180 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 30, !dbg !344 + %10181 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10139, 31, !dbg !344 + %10182 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10150, float %10151, float %10152, float %10153, float %10154, float %10155, float %10156, float %10157, float %10158, float %10159, float %10160, float %10161, float %10162, float %10163, float %10164, float %10165, float %10166, float %10167, float %10168, float %10169, float %10170, float %10171, float %10172, float %10173, float %10174, float %10175, float %10176, float %10177, float %10178, float %10179, float %10180, float %10181, i64 %10144, i64 %10149, i1 true) #3, !dbg !344 + %10183 = add i32 %9144, ptrtoint (ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096) to i32), !dbg !344 + %10184 = lshr exact i32 %10183, 4, !dbg !344 + %10185 = and i32 %10184, 16383, !dbg !344 + %10186 = zext nneg i32 %10185 to i64, !dbg !344 + %10187 = or disjoint i64 %10186, 4611686293372403712, !dbg !344 + %10188 = add i32 %9606, 8288, !dbg !344 + %10189 = lshr exact i32 %10188, 4, !dbg !344 + %10190 = and i32 %10189, 16383, !dbg !344 + %10191 = zext nneg i32 %10190 to i64, !dbg !344 + %10192 = or disjoint i64 %10191, 4611686293338849280, !dbg !344 + %10193 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 0, !dbg !344 + %10194 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 1, !dbg !344 + %10195 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 2, !dbg !344 + %10196 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 3, !dbg !344 + %10197 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 4, !dbg !344 + %10198 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 5, !dbg !344 + %10199 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 6, !dbg !344 + %10200 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 7, !dbg !344 + %10201 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 8, !dbg !344 + %10202 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 9, !dbg !344 + %10203 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 10, !dbg !344 + %10204 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 11, !dbg !344 + %10205 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 12, !dbg !344 + %10206 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 13, !dbg !344 + %10207 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 14, !dbg !344 + %10208 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 15, !dbg !344 + %10209 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 16, !dbg !344 + %10210 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 17, !dbg !344 + %10211 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 18, !dbg !344 + %10212 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 19, !dbg !344 + %10213 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 20, !dbg !344 + %10214 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 21, !dbg !344 + %10215 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 22, !dbg !344 + %10216 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 23, !dbg !344 + %10217 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 24, !dbg !344 + %10218 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 25, !dbg !344 + %10219 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 26, !dbg !344 + %10220 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 27, !dbg !344 + %10221 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 28, !dbg !344 + %10222 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 29, !dbg !344 + %10223 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 30, !dbg !344 + %10224 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10182, 31, !dbg !344 + %10225 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, $66, 1, 1, 0, 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,b"(float %10193, float %10194, float %10195, float %10196, float %10197, float %10198, float %10199, float %10200, float %10201, float %10202, float %10203, float %10204, float %10205, float %10206, float %10207, float %10208, float %10209, float %10210, float %10211, float %10212, float %10213, float %10214, float %10215, float %10216, float %10217, float %10218, float %10219, float %10220, float %10221, float %10222, float %10223, float %10224, i64 %10187, i64 %10192, i1 true) #3, !dbg !344 + %10226 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 0, !dbg !344 + %10227 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 1, !dbg !344 + %10228 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 2, !dbg !344 + %10229 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 3, !dbg !344 + %10230 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 4, !dbg !344 + %10231 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 5, !dbg !344 + %10232 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 6, !dbg !344 + %10233 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 7, !dbg !344 + %10234 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 8, !dbg !344 + %10235 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 9, !dbg !344 + %10236 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 10, !dbg !344 + %10237 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 11, !dbg !344 + %10238 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 12, !dbg !344 + %10239 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 13, !dbg !344 + %10240 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 14, !dbg !344 + %10241 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 15, !dbg !344 + %10242 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 16, !dbg !344 + %10243 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 17, !dbg !344 + %10244 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 18, !dbg !344 + %10245 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 19, !dbg !344 + %10246 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 20, !dbg !344 + %10247 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 21, !dbg !344 + %10248 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 22, !dbg !344 + %10249 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 23, !dbg !344 + %10250 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 24, !dbg !344 + %10251 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 25, !dbg !344 + %10252 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 26, !dbg !344 + %10253 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 27, !dbg !344 + %10254 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 28, !dbg !344 + %10255 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 29, !dbg !344 + %10256 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 30, !dbg !344 + %10257 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10225, 31, !dbg !344 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !344 + %10258 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37"(float %10226, float %10227, float %10228, float %10229, float %10230, float %10231, float %10232, float %10233, float %10234, float %10235, float %10236, float %10237, float %10238, float %10239, float %10240, float %10241, float %10242, float %10243, float %10244, float %10245, float %10246, float %10247, float %10248, float %10249, float %10250, float %10251, float %10252, float %10253, float %10254, float %10255, float %10256, float %10257, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 132096), i32 0, i32 0, ptr addrspace(3) %9541, i32 0, i32 0) #3, !dbg !344 + %10259 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 0, !dbg !344 + %10260 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 1, !dbg !344 + %10261 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 2, !dbg !344 + %10262 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 3, !dbg !344 + %10263 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 4, !dbg !344 + %10264 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 5, !dbg !344 + %10265 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 6, !dbg !344 + %10266 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 7, !dbg !344 + %10267 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 8, !dbg !344 + %10268 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 9, !dbg !344 + %10269 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 10, !dbg !344 + %10270 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 11, !dbg !344 + %10271 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 12, !dbg !344 + %10272 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 13, !dbg !344 + %10273 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 14, !dbg !344 + %10274 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 15, !dbg !344 + %10275 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 16, !dbg !344 + %10276 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 17, !dbg !344 + %10277 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 18, !dbg !344 + %10278 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 19, !dbg !344 + %10279 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 20, !dbg !344 + %10280 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 21, !dbg !344 + %10281 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 22, !dbg !344 + %10282 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 23, !dbg !344 + %10283 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 24, !dbg !344 + %10284 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 25, !dbg !344 + %10285 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 26, !dbg !344 + %10286 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 27, !dbg !344 + %10287 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 28, !dbg !344 + %10288 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 29, !dbg !344 + %10289 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 30, !dbg !344 + %10290 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, ptr addrspace(3), i32, i32, ptr addrspace(3), i32, i32 } %10258, 31, !dbg !344 + %10291 = fsub float %10259, %9888, !dbg !345 + %10292 = fsub float %10260, %9890, !dbg !345 + %10293 = fsub float %10261, %9888, !dbg !345 + %10294 = fsub float %10262, %9890, !dbg !345 + %10295 = fsub float %10263, %9892, !dbg !345 + %10296 = fsub float %10264, %9894, !dbg !345 + %10297 = fsub float %10265, %9892, !dbg !345 + %10298 = fsub float %10266, %9894, !dbg !345 + %10299 = fsub float %10267, %9896, !dbg !345 + %10300 = fsub float %10268, %9898, !dbg !345 + %10301 = fsub float %10269, %9896, !dbg !345 + %10302 = fsub float %10270, %9898, !dbg !345 + %10303 = fsub float %10271, %9900, !dbg !345 + %10304 = fsub float %10272, %9902, !dbg !345 + %10305 = fsub float %10273, %9900, !dbg !345 + %10306 = fsub float %10274, %9902, !dbg !345 + %10307 = fsub float %10275, %9904, !dbg !345 + %10308 = fsub float %10276, %9906, !dbg !345 + %10309 = fsub float %10277, %9904, !dbg !345 + %10310 = fsub float %10278, %9906, !dbg !345 + %10311 = fsub float %10279, %9908, !dbg !345 + %10312 = fsub float %10280, %9910, !dbg !345 + %10313 = fsub float %10281, %9908, !dbg !345 + %10314 = fsub float %10282, %9910, !dbg !345 + %10315 = fsub float %10283, %9912, !dbg !345 + %10316 = fsub float %10284, %9914, !dbg !345 + %10317 = fsub float %10285, %9912, !dbg !345 + %10318 = fsub float %10286, %9914, !dbg !345 + %10319 = fsub float %10287, %9916, !dbg !345 + %10320 = fsub float %10288, %9918, !dbg !345 + %10321 = fsub float %10289, %9916, !dbg !345 + %10322 = fsub float %10290, %9918, !dbg !345 + %10323 = fmul float %.0.i, %10291, !dbg !346 + %10324 = fmul float %.0.i1230, %10292, !dbg !346 + %10325 = fmul float %.0.i1233, %10293, !dbg !346 + %10326 = fmul float %.0.i1236, %10294, !dbg !346 + %10327 = fmul float %.0.i1239, %10295, !dbg !346 + %10328 = fmul float %.0.i1242, %10296, !dbg !346 + %10329 = fmul float %.0.i1245, %10297, !dbg !346 + %10330 = fmul float %.0.i1248, %10298, !dbg !346 + %10331 = fmul float %.0.i1251, %10299, !dbg !346 + %10332 = fmul float %.0.i1254, %10300, !dbg !346 + %10333 = fmul float %.0.i1257, %10301, !dbg !346 + %10334 = fmul float %.0.i1260, %10302, !dbg !346 + %10335 = fmul float %.0.i1263, %10303, !dbg !346 + %10336 = fmul float %.0.i1266, %10304, !dbg !346 + %10337 = fmul float %.0.i1269, %10305, !dbg !346 + %10338 = fmul float %.0.i1272, %10306, !dbg !346 + %10339 = fmul float %.0.i1275, %10307, !dbg !346 + %10340 = fmul float %.0.i1278, %10308, !dbg !346 + %10341 = fmul float %.0.i1281, %10309, !dbg !346 + %10342 = fmul float %.0.i1284, %10310, !dbg !346 + %10343 = fmul float %.0.i1287, %10311, !dbg !346 + %10344 = fmul float %.0.i1290, %10312, !dbg !346 + %10345 = fmul float %.0.i1293, %10313, !dbg !346 + %10346 = fmul float %.0.i1296, %10314, !dbg !346 + %10347 = fmul float %.0.i1299, %10315, !dbg !346 + %10348 = fmul float %.0.i1302, %10316, !dbg !346 + %10349 = fmul float %.0.i1305, %10317, !dbg !346 + %10350 = fmul float %.0.i1308, %10318, !dbg !346 + %10351 = fmul float %.0.i1311, %10319, !dbg !346 + %10352 = fmul float %.0.i1314, %10320, !dbg !346 + %10353 = fmul float %.0.i1317, %10321, !dbg !346 + %10354 = fmul float %.0.i1320, %10322, !dbg !346 + %10355 = fptrunc float %10323 to bfloat, !dbg !347 + %10356 = select i1 %8782, bfloat %10355, bfloat 0xR0000, !dbg !348 + %10357 = fptrunc float %10324 to bfloat, !dbg !347 + %10358 = select i1 %8783, bfloat %10357, bfloat 0xR0000, !dbg !348 + %10359 = fptrunc float %10325 to bfloat, !dbg !347 + %10360 = select i1 %8782, bfloat %10359, bfloat 0xR0000, !dbg !348 + %10361 = fptrunc float %10326 to bfloat, !dbg !347 + %10362 = select i1 %8783, bfloat %10361, bfloat 0xR0000, !dbg !348 + %10363 = fptrunc float %10327 to bfloat, !dbg !347 + %10364 = select i1 %8784, bfloat %10363, bfloat 0xR0000, !dbg !348 + %10365 = fptrunc float %10328 to bfloat, !dbg !347 + %10366 = select i1 %8785, bfloat %10365, bfloat 0xR0000, !dbg !348 + %10367 = fptrunc float %10329 to bfloat, !dbg !347 + %10368 = select i1 %8784, bfloat %10367, bfloat 0xR0000, !dbg !348 + %10369 = fptrunc float %10330 to bfloat, !dbg !347 + %10370 = select i1 %8785, bfloat %10369, bfloat 0xR0000, !dbg !348 + %10371 = fptrunc float %10331 to bfloat, !dbg !347 + %10372 = select i1 %8786, bfloat %10371, bfloat 0xR0000, !dbg !348 + %10373 = fptrunc float %10332 to bfloat, !dbg !347 + %10374 = select i1 %8787, bfloat %10373, bfloat 0xR0000, !dbg !348 + %10375 = fptrunc float %10333 to bfloat, !dbg !347 + %10376 = select i1 %8786, bfloat %10375, bfloat 0xR0000, !dbg !348 + %10377 = fptrunc float %10334 to bfloat, !dbg !347 + %10378 = select i1 %8787, bfloat %10377, bfloat 0xR0000, !dbg !348 + %10379 = fptrunc float %10335 to bfloat, !dbg !347 + %10380 = select i1 %8788, bfloat %10379, bfloat 0xR0000, !dbg !348 + %10381 = fptrunc float %10336 to bfloat, !dbg !347 + %10382 = select i1 %8789, bfloat %10381, bfloat 0xR0000, !dbg !348 + %10383 = fptrunc float %10337 to bfloat, !dbg !347 + %10384 = select i1 %8788, bfloat %10383, bfloat 0xR0000, !dbg !348 + %10385 = fptrunc float %10338 to bfloat, !dbg !347 + %10386 = select i1 %8789, bfloat %10385, bfloat 0xR0000, !dbg !348 + %10387 = fptrunc float %10339 to bfloat, !dbg !347 + %10388 = select i1 %8790, bfloat %10387, bfloat 0xR0000, !dbg !348 + %10389 = fptrunc float %10340 to bfloat, !dbg !347 + %10390 = select i1 %8791, bfloat %10389, bfloat 0xR0000, !dbg !348 + %10391 = fptrunc float %10341 to bfloat, !dbg !347 + %10392 = select i1 %8790, bfloat %10391, bfloat 0xR0000, !dbg !348 + %10393 = fptrunc float %10342 to bfloat, !dbg !347 + %10394 = select i1 %8791, bfloat %10393, bfloat 0xR0000, !dbg !348 + %10395 = fptrunc float %10343 to bfloat, !dbg !347 + %10396 = select i1 %8792, bfloat %10395, bfloat 0xR0000, !dbg !348 + %10397 = fptrunc float %10344 to bfloat, !dbg !347 + %10398 = select i1 %8793, bfloat %10397, bfloat 0xR0000, !dbg !348 + %10399 = fptrunc float %10345 to bfloat, !dbg !347 + %10400 = select i1 %8792, bfloat %10399, bfloat 0xR0000, !dbg !348 + %10401 = fptrunc float %10346 to bfloat, !dbg !347 + %10402 = select i1 %8793, bfloat %10401, bfloat 0xR0000, !dbg !348 + %10403 = fptrunc float %10347 to bfloat, !dbg !347 + %10404 = select i1 %8794, bfloat %10403, bfloat 0xR0000, !dbg !348 + %10405 = fptrunc float %10348 to bfloat, !dbg !347 + %10406 = select i1 %8795, bfloat %10405, bfloat 0xR0000, !dbg !348 + %10407 = fptrunc float %10349 to bfloat, !dbg !347 + %10408 = select i1 %8794, bfloat %10407, bfloat 0xR0000, !dbg !348 + %10409 = fptrunc float %10350 to bfloat, !dbg !347 + %10410 = select i1 %8795, bfloat %10409, bfloat 0xR0000, !dbg !348 + %10411 = fptrunc float %10351 to bfloat, !dbg !347 + %10412 = select i1 %8796, bfloat %10411, bfloat 0xR0000, !dbg !348 + %10413 = fptrunc float %10352 to bfloat, !dbg !347 + %10414 = select i1 %8797, bfloat %10413, bfloat 0xR0000, !dbg !348 + %10415 = fptrunc float %10353 to bfloat, !dbg !347 + %10416 = select i1 %8796, bfloat %10415, bfloat 0xR0000, !dbg !348 + %10417 = fptrunc float %10354 to bfloat, !dbg !347 + %10418 = select i1 %8797, bfloat %10417, bfloat 0xR0000, !dbg !348 + %10419 = insertelement <2 x bfloat> poison, bfloat %10356, i64 0, !dbg !349 + %10420 = insertelement <2 x bfloat> %10419, bfloat %10358, i64 1, !dbg !349 + %10421 = bitcast <2 x bfloat> %10420 to i32, !dbg !349 + %10422 = insertelement <2 x bfloat> poison, bfloat %10360, i64 0, !dbg !349 + %10423 = insertelement <2 x bfloat> %10422, bfloat %10362, i64 1, !dbg !349 + %10424 = bitcast <2 x bfloat> %10423 to i32, !dbg !349 + %10425 = insertelement <2 x bfloat> poison, bfloat %10364, i64 0, !dbg !349 + %10426 = insertelement <2 x bfloat> %10425, bfloat %10366, i64 1, !dbg !349 + %10427 = bitcast <2 x bfloat> %10426 to i32, !dbg !349 + %10428 = insertelement <2 x bfloat> poison, bfloat %10368, i64 0, !dbg !349 + %10429 = insertelement <2 x bfloat> %10428, bfloat %10370, i64 1, !dbg !349 + %10430 = bitcast <2 x bfloat> %10429 to i32, !dbg !349 + %10431 = insertelement <2 x bfloat> poison, bfloat %10372, i64 0, !dbg !349 + %10432 = insertelement <2 x bfloat> %10431, bfloat %10374, i64 1, !dbg !349 + %10433 = bitcast <2 x bfloat> %10432 to i32, !dbg !349 + %10434 = insertelement <2 x bfloat> poison, bfloat %10376, i64 0, !dbg !349 + %10435 = insertelement <2 x bfloat> %10434, bfloat %10378, i64 1, !dbg !349 + %10436 = bitcast <2 x bfloat> %10435 to i32, !dbg !349 + %10437 = insertelement <2 x bfloat> poison, bfloat %10380, i64 0, !dbg !349 + %10438 = insertelement <2 x bfloat> %10437, bfloat %10382, i64 1, !dbg !349 + %10439 = bitcast <2 x bfloat> %10438 to i32, !dbg !349 + %10440 = insertelement <2 x bfloat> poison, bfloat %10384, i64 0, !dbg !349 + %10441 = insertelement <2 x bfloat> %10440, bfloat %10386, i64 1, !dbg !349 + %10442 = bitcast <2 x bfloat> %10441 to i32, !dbg !349 + %10443 = insertelement <2 x bfloat> poison, bfloat %10388, i64 0, !dbg !349 + %10444 = insertelement <2 x bfloat> %10443, bfloat %10390, i64 1, !dbg !349 + %10445 = bitcast <2 x bfloat> %10444 to i32, !dbg !349 + %10446 = insertelement <2 x bfloat> poison, bfloat %10392, i64 0, !dbg !349 + %10447 = insertelement <2 x bfloat> %10446, bfloat %10394, i64 1, !dbg !349 + %10448 = bitcast <2 x bfloat> %10447 to i32, !dbg !349 + %10449 = insertelement <2 x bfloat> poison, bfloat %10396, i64 0, !dbg !349 + %10450 = insertelement <2 x bfloat> %10449, bfloat %10398, i64 1, !dbg !349 + %10451 = bitcast <2 x bfloat> %10450 to i32, !dbg !349 + %10452 = insertelement <2 x bfloat> poison, bfloat %10400, i64 0, !dbg !349 + %10453 = insertelement <2 x bfloat> %10452, bfloat %10402, i64 1, !dbg !349 + %10454 = bitcast <2 x bfloat> %10453 to i32, !dbg !349 + %10455 = insertelement <2 x bfloat> poison, bfloat %10404, i64 0, !dbg !349 + %10456 = insertelement <2 x bfloat> %10455, bfloat %10406, i64 1, !dbg !349 + %10457 = bitcast <2 x bfloat> %10456 to i32, !dbg !349 + %10458 = insertelement <2 x bfloat> poison, bfloat %10408, i64 0, !dbg !349 + %10459 = insertelement <2 x bfloat> %10458, bfloat %10410, i64 1, !dbg !349 + %10460 = bitcast <2 x bfloat> %10459 to i32, !dbg !349 + %10461 = insertelement <2 x bfloat> poison, bfloat %10412, i64 0, !dbg !349 + %10462 = insertelement <2 x bfloat> %10461, bfloat %10414, i64 1, !dbg !349 + %10463 = bitcast <2 x bfloat> %10462 to i32, !dbg !349 + %10464 = insertelement <2 x bfloat> poison, bfloat %10416, i64 0, !dbg !349 + %10465 = insertelement <2 x bfloat> %10464, bfloat %10418, i64 1, !dbg !349 + %10466 = bitcast <2 x bfloat> %10465 to i32, !dbg !349 + tail call void @llvm.nvvm.wgmma.fence.sync.aligned(), !dbg !349 + %10467 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %.pn3471704, float %.pn3451705, float %.pn3431706, float %.pn3411707, float %.pn3391708, float %.pn3371709, float %.pn3351710, float %.pn3331711, float %.pn3311712, float %.pn3291713, float %.pn3271714, float %.pn3251715, float %.pn3231716, float %.pn3211717, float %.pn3191718, float %.pn3171719, float %.pn3151720, float %.pn3131721, float %.pn3111722, float %.pn3091723, float %.pn3071724, float %.pn3051725, float %.pn3031726, float %.pn3011727, float %.pn2991728, float %.pn2971729, float %.pn2951730, float %.pn2931731, float %.pn2911732, float %.pn2891733, float %.pn2871734, float %.pn2851735, float %.pn2831736, float %.pn2811737, float %.pn2791738, float %.pn2771739, float %.pn2751740, float %.pn2731741, float %.pn2711742, float %.pn2691743, float %.pn2671744, float %.pn2651745, float %.pn2631746, float %.pn2611747, float %.pn2591748, float %.pn2571749, float %.pn2551750, float %.pn2531751, float %.pn2511752, float %.pn2491753, float %.pn2471754, float %.pn2451755, float %.pn2431756, float %.pn2411757, float %.pn2391758, float %.pn2371759, float %.pn2351760, float %.pn2331761, float %.pn2311762, float %.pn2291763, float %.pn2271764, float %.pn2251765, float %.pn2231766, float %.pn2211767, i32 %10421, i32 %10424, i32 %10427, i32 %10430, i64 %8878, i1 true) #3, !dbg !349 + %10468 = add i32 %8874, 2048, !dbg !349 + %10469 = lshr exact i32 %10468, 4, !dbg !349 + %10470 = and i32 %10469, 16383, !dbg !349 + %10471 = zext nneg i32 %10470 to i64, !dbg !349 + %10472 = or disjoint i64 %10471, 4611686293338849280, !dbg !349 + %10473 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 0, !dbg !349 + %10474 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 1, !dbg !349 + %10475 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 2, !dbg !349 + %10476 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 3, !dbg !349 + %10477 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 4, !dbg !349 + %10478 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 5, !dbg !349 + %10479 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 6, !dbg !349 + %10480 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 7, !dbg !349 + %10481 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 8, !dbg !349 + %10482 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 9, !dbg !349 + %10483 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 10, !dbg !349 + %10484 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 11, !dbg !349 + %10485 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 12, !dbg !349 + %10486 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 13, !dbg !349 + %10487 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 14, !dbg !349 + %10488 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 15, !dbg !349 + %10489 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 16, !dbg !349 + %10490 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 17, !dbg !349 + %10491 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 18, !dbg !349 + %10492 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 19, !dbg !349 + %10493 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 20, !dbg !349 + %10494 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 21, !dbg !349 + %10495 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 22, !dbg !349 + %10496 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 23, !dbg !349 + %10497 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 24, !dbg !349 + %10498 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 25, !dbg !349 + %10499 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 26, !dbg !349 + %10500 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 27, !dbg !349 + %10501 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 28, !dbg !349 + %10502 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 29, !dbg !349 + %10503 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 30, !dbg !349 + %10504 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 31, !dbg !349 + %10505 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 32, !dbg !349 + %10506 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 33, !dbg !349 + %10507 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 34, !dbg !349 + %10508 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 35, !dbg !349 + %10509 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 36, !dbg !349 + %10510 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 37, !dbg !349 + %10511 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 38, !dbg !349 + %10512 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 39, !dbg !349 + %10513 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 40, !dbg !349 + %10514 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 41, !dbg !349 + %10515 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 42, !dbg !349 + %10516 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 43, !dbg !349 + %10517 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 44, !dbg !349 + %10518 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 45, !dbg !349 + %10519 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 46, !dbg !349 + %10520 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 47, !dbg !349 + %10521 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 48, !dbg !349 + %10522 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 49, !dbg !349 + %10523 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 50, !dbg !349 + %10524 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 51, !dbg !349 + %10525 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 52, !dbg !349 + %10526 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 53, !dbg !349 + %10527 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 54, !dbg !349 + %10528 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 55, !dbg !349 + %10529 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 56, !dbg !349 + %10530 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 57, !dbg !349 + %10531 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 58, !dbg !349 + %10532 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 59, !dbg !349 + %10533 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 60, !dbg !349 + %10534 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 61, !dbg !349 + %10535 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 62, !dbg !349 + %10536 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10467, 63, !dbg !349 + %10537 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10473, float %10474, float %10475, float %10476, float %10477, float %10478, float %10479, float %10480, float %10481, float %10482, float %10483, float %10484, float %10485, float %10486, float %10487, float %10488, float %10489, float %10490, float %10491, float %10492, float %10493, float %10494, float %10495, float %10496, float %10497, float %10498, float %10499, float %10500, float %10501, float %10502, float %10503, float %10504, float %10505, float %10506, float %10507, float %10508, float %10509, float %10510, float %10511, float %10512, float %10513, float %10514, float %10515, float %10516, float %10517, float %10518, float %10519, float %10520, float %10521, float %10522, float %10523, float %10524, float %10525, float %10526, float %10527, float %10528, float %10529, float %10530, float %10531, float %10532, float %10533, float %10534, float %10535, float %10536, i32 %10433, i32 %10436, i32 %10439, i32 %10442, i64 %10472, i1 true) #3, !dbg !349 + %10538 = add i32 %8874, 4096, !dbg !349 + %10539 = lshr exact i32 %10538, 4, !dbg !349 + %10540 = and i32 %10539, 16383, !dbg !349 + %10541 = zext nneg i32 %10540 to i64, !dbg !349 + %10542 = or disjoint i64 %10541, 4611686293338849280, !dbg !349 + %10543 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 0, !dbg !349 + %10544 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 1, !dbg !349 + %10545 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 2, !dbg !349 + %10546 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 3, !dbg !349 + %10547 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 4, !dbg !349 + %10548 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 5, !dbg !349 + %10549 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 6, !dbg !349 + %10550 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 7, !dbg !349 + %10551 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 8, !dbg !349 + %10552 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 9, !dbg !349 + %10553 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 10, !dbg !349 + %10554 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 11, !dbg !349 + %10555 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 12, !dbg !349 + %10556 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 13, !dbg !349 + %10557 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 14, !dbg !349 + %10558 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 15, !dbg !349 + %10559 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 16, !dbg !349 + %10560 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 17, !dbg !349 + %10561 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 18, !dbg !349 + %10562 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 19, !dbg !349 + %10563 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 20, !dbg !349 + %10564 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 21, !dbg !349 + %10565 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 22, !dbg !349 + %10566 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 23, !dbg !349 + %10567 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 24, !dbg !349 + %10568 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 25, !dbg !349 + %10569 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 26, !dbg !349 + %10570 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 27, !dbg !349 + %10571 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 28, !dbg !349 + %10572 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 29, !dbg !349 + %10573 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 30, !dbg !349 + %10574 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 31, !dbg !349 + %10575 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 32, !dbg !349 + %10576 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 33, !dbg !349 + %10577 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 34, !dbg !349 + %10578 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 35, !dbg !349 + %10579 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 36, !dbg !349 + %10580 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 37, !dbg !349 + %10581 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 38, !dbg !349 + %10582 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 39, !dbg !349 + %10583 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 40, !dbg !349 + %10584 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 41, !dbg !349 + %10585 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 42, !dbg !349 + %10586 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 43, !dbg !349 + %10587 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 44, !dbg !349 + %10588 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 45, !dbg !349 + %10589 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 46, !dbg !349 + %10590 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 47, !dbg !349 + %10591 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 48, !dbg !349 + %10592 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 49, !dbg !349 + %10593 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 50, !dbg !349 + %10594 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 51, !dbg !349 + %10595 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 52, !dbg !349 + %10596 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 53, !dbg !349 + %10597 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 54, !dbg !349 + %10598 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 55, !dbg !349 + %10599 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 56, !dbg !349 + %10600 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 57, !dbg !349 + %10601 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 58, !dbg !349 + %10602 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 59, !dbg !349 + %10603 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 60, !dbg !349 + %10604 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 61, !dbg !349 + %10605 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 62, !dbg !349 + %10606 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10537, 63, !dbg !349 + %10607 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10543, float %10544, float %10545, float %10546, float %10547, float %10548, float %10549, float %10550, float %10551, float %10552, float %10553, float %10554, float %10555, float %10556, float %10557, float %10558, float %10559, float %10560, float %10561, float %10562, float %10563, float %10564, float %10565, float %10566, float %10567, float %10568, float %10569, float %10570, float %10571, float %10572, float %10573, float %10574, float %10575, float %10576, float %10577, float %10578, float %10579, float %10580, float %10581, float %10582, float %10583, float %10584, float %10585, float %10586, float %10587, float %10588, float %10589, float %10590, float %10591, float %10592, float %10593, float %10594, float %10595, float %10596, float %10597, float %10598, float %10599, float %10600, float %10601, float %10602, float %10603, float %10604, float %10605, float %10606, i32 %10445, i32 %10448, i32 %10451, i32 %10454, i64 %10542, i1 true) #3, !dbg !349 + %10608 = add i32 %8874, 6144, !dbg !349 + %10609 = lshr exact i32 %10608, 4, !dbg !349 + %10610 = and i32 %10609, 16383, !dbg !349 + %10611 = zext nneg i32 %10610 to i64, !dbg !349 + %10612 = or disjoint i64 %10611, 4611686293338849280, !dbg !349 + %10613 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 0, !dbg !349 + %10614 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 1, !dbg !349 + %10615 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 2, !dbg !349 + %10616 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 3, !dbg !349 + %10617 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 4, !dbg !349 + %10618 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 5, !dbg !349 + %10619 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 6, !dbg !349 + %10620 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 7, !dbg !349 + %10621 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 8, !dbg !349 + %10622 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 9, !dbg !349 + %10623 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 10, !dbg !349 + %10624 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 11, !dbg !349 + %10625 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 12, !dbg !349 + %10626 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 13, !dbg !349 + %10627 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 14, !dbg !349 + %10628 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 15, !dbg !349 + %10629 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 16, !dbg !349 + %10630 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 17, !dbg !349 + %10631 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 18, !dbg !349 + %10632 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 19, !dbg !349 + %10633 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 20, !dbg !349 + %10634 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 21, !dbg !349 + %10635 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 22, !dbg !349 + %10636 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 23, !dbg !349 + %10637 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 24, !dbg !349 + %10638 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 25, !dbg !349 + %10639 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 26, !dbg !349 + %10640 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 27, !dbg !349 + %10641 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 28, !dbg !349 + %10642 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 29, !dbg !349 + %10643 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 30, !dbg !349 + %10644 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 31, !dbg !349 + %10645 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 32, !dbg !349 + %10646 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 33, !dbg !349 + %10647 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 34, !dbg !349 + %10648 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 35, !dbg !349 + %10649 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 36, !dbg !349 + %10650 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 37, !dbg !349 + %10651 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 38, !dbg !349 + %10652 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 39, !dbg !349 + %10653 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 40, !dbg !349 + %10654 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 41, !dbg !349 + %10655 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 42, !dbg !349 + %10656 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 43, !dbg !349 + %10657 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 44, !dbg !349 + %10658 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 45, !dbg !349 + %10659 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 46, !dbg !349 + %10660 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 47, !dbg !349 + %10661 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 48, !dbg !349 + %10662 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 49, !dbg !349 + %10663 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 50, !dbg !349 + %10664 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 51, !dbg !349 + %10665 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 52, !dbg !349 + %10666 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 53, !dbg !349 + %10667 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 54, !dbg !349 + %10668 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 55, !dbg !349 + %10669 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 56, !dbg !349 + %10670 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 57, !dbg !349 + %10671 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 58, !dbg !349 + %10672 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 59, !dbg !349 + %10673 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 60, !dbg !349 + %10674 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 61, !dbg !349 + %10675 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 62, !dbg !349 + %10676 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10607, 63, !dbg !349 + %10677 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63}, {$128,$129,$130,$131}, $132, $133, 1, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,r,r,r,r,l,b"(float %10613, float %10614, float %10615, float %10616, float %10617, float %10618, float %10619, float %10620, float %10621, float %10622, float %10623, float %10624, float %10625, float %10626, float %10627, float %10628, float %10629, float %10630, float %10631, float %10632, float %10633, float %10634, float %10635, float %10636, float %10637, float %10638, float %10639, float %10640, float %10641, float %10642, float %10643, float %10644, float %10645, float %10646, float %10647, float %10648, float %10649, float %10650, float %10651, float %10652, float %10653, float %10654, float %10655, float %10656, float %10657, float %10658, float %10659, float %10660, float %10661, float %10662, float %10663, float %10664, float %10665, float %10666, float %10667, float %10668, float %10669, float %10670, float %10671, float %10672, float %10673, float %10674, float %10675, float %10676, i32 %10457, i32 %10460, i32 %10463, i32 %10466, i64 %10612, i1 true) #3, !dbg !349 + %10678 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 0, !dbg !349 + %10679 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 1, !dbg !349 + %10680 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 2, !dbg !349 + %10681 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 3, !dbg !349 + %10682 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 4, !dbg !349 + %10683 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 5, !dbg !349 + %10684 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 6, !dbg !349 + %10685 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 7, !dbg !349 + %10686 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 8, !dbg !349 + %10687 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 9, !dbg !349 + %10688 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 10, !dbg !349 + %10689 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 11, !dbg !349 + %10690 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 12, !dbg !349 + %10691 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 13, !dbg !349 + %10692 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 14, !dbg !349 + %10693 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 15, !dbg !349 + %10694 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 16, !dbg !349 + %10695 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 17, !dbg !349 + %10696 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 18, !dbg !349 + %10697 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 19, !dbg !349 + %10698 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 20, !dbg !349 + %10699 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 21, !dbg !349 + %10700 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 22, !dbg !349 + %10701 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 23, !dbg !349 + %10702 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 24, !dbg !349 + %10703 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 25, !dbg !349 + %10704 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 26, !dbg !349 + %10705 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 27, !dbg !349 + %10706 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 28, !dbg !349 + %10707 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 29, !dbg !349 + %10708 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 30, !dbg !349 + %10709 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 31, !dbg !349 + %10710 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 32, !dbg !349 + %10711 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 33, !dbg !349 + %10712 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 34, !dbg !349 + %10713 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 35, !dbg !349 + %10714 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 36, !dbg !349 + %10715 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 37, !dbg !349 + %10716 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 38, !dbg !349 + %10717 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 39, !dbg !349 + %10718 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 40, !dbg !349 + %10719 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 41, !dbg !349 + %10720 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 42, !dbg !349 + %10721 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 43, !dbg !349 + %10722 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 44, !dbg !349 + %10723 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 45, !dbg !349 + %10724 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 46, !dbg !349 + %10725 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 47, !dbg !349 + %10726 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 48, !dbg !349 + %10727 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 49, !dbg !349 + %10728 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 50, !dbg !349 + %10729 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 51, !dbg !349 + %10730 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 52, !dbg !349 + %10731 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 53, !dbg !349 + %10732 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 54, !dbg !349 + %10733 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 55, !dbg !349 + %10734 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 56, !dbg !349 + %10735 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 57, !dbg !349 + %10736 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 58, !dbg !349 + %10737 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 59, !dbg !349 + %10738 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 60, !dbg !349 + %10739 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 61, !dbg !349 + %10740 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 62, !dbg !349 + %10741 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10677, 63, !dbg !349 + tail call void @llvm.nvvm.wgmma.commit_group.sync.aligned(), !dbg !349 + %10742 = add nuw nsw i32 %8773, 1, !dbg !332 + %10743 = lshr i32 %10742, 1, !dbg !350 + %10744 = zext nneg i32 %10743 to i64, !dbg !351 + %10745 = getelementptr i32, ptr addrspace(1) %5036, i64 %10744, !dbg !351 + %10746 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !352 + %10747 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %10745, i64 %10746, i1 %8775) #3, !dbg !352 + %10748 = add nuw nsw i32 %10743, 1, !dbg !353 + %10749 = icmp slt i32 %10748, %5040, !dbg !354 + %10750 = getelementptr i8, ptr addrspace(1) %10745, i64 4, !dbg !355 + %10751 = and i1 %8775, %10749, !dbg !332 + %10752 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09createpolicy.fractional.L2::evict_last.b64 $0, 1.0;", "=l"() #3, !dbg !356 + %10753 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$3 ld.global.L1::evict_last.L2::cache_hint.b32 { $0 }, [ $1 + 0 ], $2;", "=r,l,l,b"(ptr addrspace(1) %10750, i64 %10752, i1 %10751) #3, !dbg !356 + %10754 = and i32 %8773, 1, !dbg !357 + %10755 = sub i32 %10753, %10747, !dbg !358 + %10756 = shl i32 %10755, 7, !dbg !359 + %10757 = add i32 %10756, -64, !dbg !360 + %10758 = xor i32 %10754, 1, !dbg !361 + %10759 = mul nuw nsw i32 %10757, %10758, !dbg !361 + %10760 = shl nuw nsw i32 %10754, 6, !dbg !362 + %10761 = add i32 %10759, %10760, !dbg !363 + %10762 = shl i32 %10761, 12, !dbg !364 + %10763 = sext i32 %10762 to i64, !dbg !330 + %10764 = getelementptr bfloat, ptr addrspace(1) %.pn5391832, i64 %10763, !dbg !330 + %10765 = getelementptr bfloat, ptr addrspace(1) %.pn5231833, i64 %10763, !dbg !330 + %10766 = getelementptr bfloat, ptr addrspace(1) %.pn5071834, i64 %10763, !dbg !330 + %10767 = getelementptr bfloat, ptr addrspace(1) %.pn4911835, i64 %10763, !dbg !330 + %10768 = shl i32 %10761, 7, !dbg !365 + %10769 = sext i32 %10768 to i64, !dbg !331 + %10770 = getelementptr bfloat, ptr addrspace(1) %.pn6031836, i64 %10769, !dbg !331 + %10771 = getelementptr bfloat, ptr addrspace(1) %.pn5871837, i64 %10769, !dbg !331 + %10772 = getelementptr bfloat, ptr addrspace(1) %.pn5711838, i64 %10769, !dbg !331 + %10773 = getelementptr bfloat, ptr addrspace(1) %.pn5551839, i64 %10769, !dbg !331 + %10774 = add i32 %10761, %.pn6351840, !dbg !366 + %10775 = add i32 %10761, %.pn6331841, !dbg !366 + %10776 = add i32 %10761, %.pn6311842, !dbg !366 + %10777 = add i32 %10761, %.pn6291843, !dbg !366 + %10778 = add i32 %10761, %.pn6271844, !dbg !366 + %10779 = add i32 %10761, %.pn6251845, !dbg !366 + %10780 = add i32 %10761, %.pn6231846, !dbg !366 + %10781 = add i32 %10761, %.pn6211847, !dbg !366 + %10782 = add i32 %10761, %.pn6191848, !dbg !366 + %10783 = add i32 %10761, %.pn6171849, !dbg !366 + %10784 = add i32 %10761, %.pn6151850, !dbg !366 + %10785 = add i32 %10761, %.pn6131851, !dbg !366 + %10786 = add i32 %10761, %.pn6111852, !dbg !366 + %10787 = add i32 %10761, %.pn6091853, !dbg !366 + %10788 = add i32 %10761, %.pn6071854, !dbg !366 + %10789 = add i32 %10761, %.pn6051855, !dbg !366 + %10790 = add i32 %10761, %8769, !dbg !366 + %10791 = add i32 %10761, %8770, !dbg !366 + %10792 = add i32 %10761, %8771, !dbg !366 + %10793 = add i32 %10761, %8772, !dbg !366 + %10794 = add i32 %10761, %8765, !dbg !366 + %10795 = add i32 %10761, %8766, !dbg !366 + %10796 = add i32 %10761, %8767, !dbg !366 + %10797 = add i32 %10761, %8768, !dbg !366 + %10798 = add i32 %8762, 1, !dbg !332 + %10799 = icmp sgt i32 %10798, 1, !dbg !332 + %10800 = select i1 %10799, i32 0, i32 %10798, !dbg !332 + %10801 = add i32 %8764, 1, !dbg !332 + %10802 = icmp sgt i32 %10801, 2, !dbg !332 + %10803 = select i1 %10802, i32 0, i32 %10801, !dbg !332 + %10804 = icmp slt i32 %10790, %17, !dbg !333 + %10805 = icmp slt i32 %10791, %17, !dbg !333 + %10806 = icmp slt i32 %10792, %17, !dbg !333 + %10807 = icmp slt i32 %10793, %17, !dbg !333 + %10808 = shl i32 %10803, 13, !dbg !324 + %10809 = getelementptr bfloat, ptr addrspace(3) @global_smem, i32 %10808, !dbg !324 + %10810 = and i1 %8774, %10804, !dbg !332 + %10811 = and i1 %8774, %10805, !dbg !332 + %10812 = and i1 %8774, %10806, !dbg !332 + %10813 = and i1 %8774, %10807, !dbg !332 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !324 + %10814 = getelementptr inbounds nuw i8, ptr addrspace(3) %10809, i32 %5101, !dbg !324 + %10815 = select i1 %10810, i32 16, i32 0, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %10814, ptr addrspace(1) %10764, i32 %10815) #3, !dbg !324 + %10816 = getelementptr inbounds nuw i8, ptr addrspace(3) %10809, i32 %5104, !dbg !324 + %10817 = select i1 %10811, i32 16, i32 0, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10816, ptr addrspace(1) %10765, i32 %10817) #3, !dbg !324 + %10818 = getelementptr inbounds nuw i8, ptr addrspace(3) %10809, i32 %5107, !dbg !324 + %10819 = select i1 %10812, i32 16, i32 0, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10818, ptr addrspace(1) %10766, i32 %10819) #3, !dbg !324 + %10820 = getelementptr inbounds nuw i8, ptr addrspace(3) %10809, i32 %5110, !dbg !324 + %10821 = select i1 %10813, i32 16, i32 0, !dbg !324 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10820, ptr addrspace(1) %10767, i32 %10821) #3, !dbg !324 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !324 + %10822 = icmp slt i32 %10774, %17, !dbg !367 + %10823 = icmp slt i32 %10775, %17, !dbg !367 + %10824 = icmp slt i32 %10776, %17, !dbg !367 + %10825 = icmp slt i32 %10777, %17, !dbg !367 + %10826 = icmp slt i32 %10778, %17, !dbg !367 + %10827 = icmp slt i32 %10779, %17, !dbg !367 + %10828 = icmp slt i32 %10780, %17, !dbg !367 + %10829 = icmp slt i32 %10781, %17, !dbg !367 + %10830 = icmp slt i32 %10782, %17, !dbg !367 + %10831 = icmp slt i32 %10783, %17, !dbg !367 + %10832 = icmp slt i32 %10784, %17, !dbg !367 + %10833 = icmp slt i32 %10785, %17, !dbg !367 + %10834 = icmp slt i32 %10786, %17, !dbg !367 + %10835 = icmp slt i32 %10787, %17, !dbg !367 + %10836 = icmp slt i32 %10788, %17, !dbg !367 + %10837 = icmp slt i32 %10789, %17, !dbg !367 + %10838 = sext i32 %10774 to i64, !dbg !325 + %10839 = getelementptr float, ptr addrspace(1) %5718, i64 %10838, !dbg !325 + %10840 = sext i32 %10775 to i64, !dbg !325 + %10841 = getelementptr float, ptr addrspace(1) %5718, i64 %10840, !dbg !325 + %10842 = sext i32 %10776 to i64, !dbg !325 + %10843 = getelementptr float, ptr addrspace(1) %5718, i64 %10842, !dbg !325 + %10844 = sext i32 %10777 to i64, !dbg !325 + %10845 = getelementptr float, ptr addrspace(1) %5718, i64 %10844, !dbg !325 + %10846 = sext i32 %10778 to i64, !dbg !325 + %10847 = getelementptr float, ptr addrspace(1) %5718, i64 %10846, !dbg !325 + %10848 = sext i32 %10779 to i64, !dbg !325 + %10849 = getelementptr float, ptr addrspace(1) %5718, i64 %10848, !dbg !325 + %10850 = sext i32 %10780 to i64, !dbg !325 + %10851 = getelementptr float, ptr addrspace(1) %5718, i64 %10850, !dbg !325 + %10852 = sext i32 %10781 to i64, !dbg !325 + %10853 = getelementptr float, ptr addrspace(1) %5718, i64 %10852, !dbg !325 + %10854 = sext i32 %10782 to i64, !dbg !325 + %10855 = getelementptr float, ptr addrspace(1) %5718, i64 %10854, !dbg !325 + %10856 = sext i32 %10783 to i64, !dbg !325 + %10857 = getelementptr float, ptr addrspace(1) %5718, i64 %10856, !dbg !325 + %10858 = sext i32 %10784 to i64, !dbg !325 + %10859 = getelementptr float, ptr addrspace(1) %5718, i64 %10858, !dbg !325 + %10860 = sext i32 %10785 to i64, !dbg !325 + %10861 = getelementptr float, ptr addrspace(1) %5718, i64 %10860, !dbg !325 + %10862 = sext i32 %10786 to i64, !dbg !325 + %10863 = getelementptr float, ptr addrspace(1) %5718, i64 %10862, !dbg !325 + %10864 = sext i32 %10787 to i64, !dbg !325 + %10865 = getelementptr float, ptr addrspace(1) %5718, i64 %10864, !dbg !325 + %10866 = sext i32 %10788 to i64, !dbg !325 + %10867 = getelementptr float, ptr addrspace(1) %5718, i64 %10866, !dbg !325 + %10868 = sext i32 %10789 to i64, !dbg !325 + %10869 = getelementptr float, ptr addrspace(1) %5718, i64 %10868, !dbg !325 + %10870 = shl i32 %10800, 6, !dbg !326 + %10871 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98304), i32 %10870, !dbg !326 + %10872 = and i1 %8774, %10822, !dbg !332 + %10873 = and i1 %8774, %10823, !dbg !332 + %10874 = and i1 %8774, %10824, !dbg !332 + %10875 = and i1 %8774, %10825, !dbg !332 + %10876 = and i1 %8774, %10826, !dbg !332 + %10877 = and i1 %8774, %10827, !dbg !332 + %10878 = and i1 %8774, %10828, !dbg !332 + %10879 = and i1 %8774, %10829, !dbg !332 + %10880 = and i1 %8774, %10830, !dbg !332 + %10881 = and i1 %8774, %10831, !dbg !332 + %10882 = and i1 %8774, %10832, !dbg !332 + %10883 = and i1 %8774, %10833, !dbg !332 + %10884 = and i1 %8774, %10834, !dbg !332 + %10885 = and i1 %8774, %10835, !dbg !332 + %10886 = and i1 %8774, %10836, !dbg !332 + %10887 = and i1 %8774, %10837, !dbg !332 + %10888 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5179, !dbg !326 + %10889 = select i1 %10872, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %10888, ptr addrspace(1) %10839, i32 %10889, i1 %5178) #3, !dbg !326 + %10890 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5182, !dbg !326 + %10891 = select i1 %10873, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10890, ptr addrspace(1) %10841, i32 %10891, i1 %5178) #3, !dbg !326 + %10892 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5185, !dbg !326 + %10893 = select i1 %10874, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10892, ptr addrspace(1) %10843, i32 %10893, i1 %5178) #3, !dbg !326 + %10894 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5188, !dbg !326 + %10895 = select i1 %10875, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10894, ptr addrspace(1) %10845, i32 %10895, i1 %5178) #3, !dbg !326 + %10896 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5191, !dbg !326 + %10897 = select i1 %10876, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10896, ptr addrspace(1) %10847, i32 %10897, i1 %5178) #3, !dbg !326 + %10898 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5194, !dbg !326 + %10899 = select i1 %10877, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10898, ptr addrspace(1) %10849, i32 %10899, i1 %5178) #3, !dbg !326 + %10900 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5197, !dbg !326 + %10901 = select i1 %10878, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10900, ptr addrspace(1) %10851, i32 %10901, i1 %5178) #3, !dbg !326 + %10902 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5200, !dbg !326 + %10903 = select i1 %10879, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10902, ptr addrspace(1) %10853, i32 %10903, i1 %5178) #3, !dbg !326 + %10904 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5203, !dbg !326 + %10905 = select i1 %10880, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10904, ptr addrspace(1) %10855, i32 %10905, i1 %5178) #3, !dbg !326 + %10906 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5206, !dbg !326 + %10907 = select i1 %10881, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10906, ptr addrspace(1) %10857, i32 %10907, i1 %5178) #3, !dbg !326 + %10908 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5209, !dbg !326 + %10909 = select i1 %10882, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10908, ptr addrspace(1) %10859, i32 %10909, i1 %5178) #3, !dbg !326 + %10910 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5212, !dbg !326 + %10911 = select i1 %10883, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10910, ptr addrspace(1) %10861, i32 %10911, i1 %5178) #3, !dbg !326 + %10912 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5215, !dbg !326 + %10913 = select i1 %10884, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10912, ptr addrspace(1) %10863, i32 %10913, i1 %5178) #3, !dbg !326 + %10914 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5218, !dbg !326 + %10915 = select i1 %10885, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10914, ptr addrspace(1) %10865, i32 %10915, i1 %5178) #3, !dbg !326 + %10916 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5221, !dbg !326 + %10917 = select i1 %10886, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10916, ptr addrspace(1) %10867, i32 %10917, i1 %5178) #3, !dbg !326 + %10918 = getelementptr inbounds nuw i8, ptr addrspace(3) %10871, i32 %5224, !dbg !326 + %10919 = select i1 %10887, i32 4, i32 0, !dbg !326 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10918, ptr addrspace(1) %10869, i32 %10919, i1 %5178) #3, !dbg !326 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !326 + %10920 = icmp slt i32 %10794, %17, !dbg !368 + %10921 = icmp slt i32 %10795, %17, !dbg !368 + %10922 = icmp slt i32 %10796, %17, !dbg !368 + %10923 = icmp slt i32 %10797, %17, !dbg !368 + %10924 = getelementptr bfloat, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 49152), i32 %10808, !dbg !327 + %10925 = and i1 %8774, %10920, !dbg !332 + %10926 = and i1 %8774, %10921, !dbg !332 + %10927 = and i1 %8774, %10922, !dbg !332 + %10928 = and i1 %8774, %10923, !dbg !332 + %10929 = getelementptr inbounds nuw i8, ptr addrspace(3) %10924, i32 %5101, !dbg !327 + %10930 = select i1 %10925, i32 16, i32 0, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) %10929, ptr addrspace(1) %10770, i32 %10930) #3, !dbg !327 + %10931 = getelementptr inbounds nuw i8, ptr addrspace(3) %10924, i32 %5104, !dbg !327 + %10932 = select i1 %10926, i32 16, i32 0, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10931, ptr addrspace(1) %10771, i32 %10932) #3, !dbg !327 + %10933 = getelementptr inbounds nuw i8, ptr addrspace(3) %10924, i32 %5107, !dbg !327 + %10934 = select i1 %10927, i32 16, i32 0, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10933, ptr addrspace(1) %10772, i32 %10934) #3, !dbg !327 + %10935 = getelementptr inbounds nuw i8, ptr addrspace(3) %10924, i32 %5110, !dbg !327 + %10936 = select i1 %10928, i32 16, i32 0, !dbg !327 + tail call void asm sideeffect "cp.async.cg.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x10, $2;", "r,l,r"(ptr addrspace(3) nonnull %10935, ptr addrspace(1) %10773, i32 %10936) #3, !dbg !327 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !327 + %10937 = getelementptr float, ptr addrspace(1) %5719, i64 %10838, !dbg !328 + %10938 = getelementptr float, ptr addrspace(1) %5719, i64 %10840, !dbg !328 + %10939 = getelementptr float, ptr addrspace(1) %5719, i64 %10842, !dbg !328 + %10940 = getelementptr float, ptr addrspace(1) %5719, i64 %10844, !dbg !328 + %10941 = getelementptr float, ptr addrspace(1) %5719, i64 %10846, !dbg !328 + %10942 = getelementptr float, ptr addrspace(1) %5719, i64 %10848, !dbg !328 + %10943 = getelementptr float, ptr addrspace(1) %5719, i64 %10850, !dbg !328 + %10944 = getelementptr float, ptr addrspace(1) %5719, i64 %10852, !dbg !328 + %10945 = getelementptr float, ptr addrspace(1) %5719, i64 %10854, !dbg !328 + %10946 = getelementptr float, ptr addrspace(1) %5719, i64 %10856, !dbg !328 + %10947 = getelementptr float, ptr addrspace(1) %5719, i64 %10858, !dbg !328 + %10948 = getelementptr float, ptr addrspace(1) %5719, i64 %10860, !dbg !328 + %10949 = getelementptr float, ptr addrspace(1) %5719, i64 %10862, !dbg !328 + %10950 = getelementptr float, ptr addrspace(1) %5719, i64 %10864, !dbg !328 + %10951 = getelementptr float, ptr addrspace(1) %5719, i64 %10866, !dbg !328 + %10952 = getelementptr float, ptr addrspace(1) %5719, i64 %10868, !dbg !328 + %10953 = getelementptr float, ptr addrspace(3) getelementptr (i8, ptr addrspace(3) @global_smem, i32 98816), i32 %10870, !dbg !329 + %10954 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5179, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) %10954, ptr addrspace(1) %10937, i32 %10889, i1 %5178) #3, !dbg !329 + %10955 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5182, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10955, ptr addrspace(1) %10938, i32 %10891, i1 %5178) #3, !dbg !329 + %10956 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5185, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10956, ptr addrspace(1) %10939, i32 %10893, i1 %5178) #3, !dbg !329 + %10957 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5188, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10957, ptr addrspace(1) %10940, i32 %10895, i1 %5178) #3, !dbg !329 + %10958 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5191, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10958, ptr addrspace(1) %10941, i32 %10897, i1 %5178) #3, !dbg !329 + %10959 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5194, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10959, ptr addrspace(1) %10942, i32 %10899, i1 %5178) #3, !dbg !329 + %10960 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5197, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10960, ptr addrspace(1) %10943, i32 %10901, i1 %5178) #3, !dbg !329 + %10961 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5200, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10961, ptr addrspace(1) %10944, i32 %10903, i1 %5178) #3, !dbg !329 + %10962 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5203, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10962, ptr addrspace(1) %10945, i32 %10905, i1 %5178) #3, !dbg !329 + %10963 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5206, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10963, ptr addrspace(1) %10946, i32 %10907, i1 %5178) #3, !dbg !329 + %10964 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5209, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10964, ptr addrspace(1) %10947, i32 %10909, i1 %5178) #3, !dbg !329 + %10965 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5212, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10965, ptr addrspace(1) %10948, i32 %10911, i1 %5178) #3, !dbg !329 + %10966 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5215, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10966, ptr addrspace(1) %10949, i32 %10913, i1 %5178) #3, !dbg !329 + %10967 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5218, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10967, ptr addrspace(1) %10950, i32 %10915, i1 %5178) #3, !dbg !329 + %10968 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5221, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10968, ptr addrspace(1) %10951, i32 %10917, i1 %5178) #3, !dbg !329 + %10969 = getelementptr inbounds nuw i8, ptr addrspace(3) %10953, i32 %5224, !dbg !329 + tail call void asm sideeffect "@$3 cp.async.ca.shared.global [ $0 + 0 ], [ $1 + 0 ], 0x4, $2;", "r,l,r,b"(ptr addrspace(3) nonnull %10969, ptr addrspace(1) %10952, i32 %10919, i1 %5178) #3, !dbg !329 + tail call void @llvm.nvvm.cp.async.commit.group(), !dbg !329 + %exitcond2268.not = icmp eq i32 %10742, %smax2267, !dbg !332 + br i1 %exitcond2268.not, label %._crit_edge1874, label %.lr.ph1873, !dbg !332 + +._crit_edge1874: ; preds = %__nv_exp2f.exit1321, %._crit_edge1701 + %.pn347.lcssa = phi float [ %8609, %._crit_edge1701 ], [ %10678, %__nv_exp2f.exit1321 ] + %.pn345.lcssa = phi float [ %8610, %._crit_edge1701 ], [ %10679, %__nv_exp2f.exit1321 ] + %.pn343.lcssa = phi float [ %8611, %._crit_edge1701 ], [ %10680, %__nv_exp2f.exit1321 ] + %.pn341.lcssa = phi float [ %8612, %._crit_edge1701 ], [ %10681, %__nv_exp2f.exit1321 ] + %.pn339.lcssa = phi float [ %8613, %._crit_edge1701 ], [ %10682, %__nv_exp2f.exit1321 ] + %.pn337.lcssa = phi float [ %8614, %._crit_edge1701 ], [ %10683, %__nv_exp2f.exit1321 ] + %.pn335.lcssa = phi float [ %8615, %._crit_edge1701 ], [ %10684, %__nv_exp2f.exit1321 ] + %.pn333.lcssa = phi float [ %8616, %._crit_edge1701 ], [ %10685, %__nv_exp2f.exit1321 ] + %.pn331.lcssa = phi float [ %8617, %._crit_edge1701 ], [ %10686, %__nv_exp2f.exit1321 ] + %.pn329.lcssa = phi float [ %8618, %._crit_edge1701 ], [ %10687, %__nv_exp2f.exit1321 ] + %.pn327.lcssa = phi float [ %8619, %._crit_edge1701 ], [ %10688, %__nv_exp2f.exit1321 ] + %.pn325.lcssa = phi float [ %8620, %._crit_edge1701 ], [ %10689, %__nv_exp2f.exit1321 ] + %.pn323.lcssa = phi float [ %8621, %._crit_edge1701 ], [ %10690, %__nv_exp2f.exit1321 ] + %.pn321.lcssa = phi float [ %8622, %._crit_edge1701 ], [ %10691, %__nv_exp2f.exit1321 ] + %.pn319.lcssa = phi float [ %8623, %._crit_edge1701 ], [ %10692, %__nv_exp2f.exit1321 ] + %.pn317.lcssa = phi float [ %8624, %._crit_edge1701 ], [ %10693, %__nv_exp2f.exit1321 ] + %.pn315.lcssa = phi float [ %8625, %._crit_edge1701 ], [ %10694, %__nv_exp2f.exit1321 ] + %.pn313.lcssa = phi float [ %8626, %._crit_edge1701 ], [ %10695, %__nv_exp2f.exit1321 ] + %.pn311.lcssa = phi float [ %8627, %._crit_edge1701 ], [ %10696, %__nv_exp2f.exit1321 ] + %.pn309.lcssa = phi float [ %8628, %._crit_edge1701 ], [ %10697, %__nv_exp2f.exit1321 ] + %.pn307.lcssa = phi float [ %8629, %._crit_edge1701 ], [ %10698, %__nv_exp2f.exit1321 ] + %.pn305.lcssa = phi float [ %8630, %._crit_edge1701 ], [ %10699, %__nv_exp2f.exit1321 ] + %.pn303.lcssa = phi float [ %8631, %._crit_edge1701 ], [ %10700, %__nv_exp2f.exit1321 ] + %.pn301.lcssa = phi float [ %8632, %._crit_edge1701 ], [ %10701, %__nv_exp2f.exit1321 ] + %.pn299.lcssa = phi float [ %8633, %._crit_edge1701 ], [ %10702, %__nv_exp2f.exit1321 ] + %.pn297.lcssa = phi float [ %8634, %._crit_edge1701 ], [ %10703, %__nv_exp2f.exit1321 ] + %.pn295.lcssa = phi float [ %8635, %._crit_edge1701 ], [ %10704, %__nv_exp2f.exit1321 ] + %.pn293.lcssa = phi float [ %8636, %._crit_edge1701 ], [ %10705, %__nv_exp2f.exit1321 ] + %.pn291.lcssa = phi float [ %8637, %._crit_edge1701 ], [ %10706, %__nv_exp2f.exit1321 ] + %.pn289.lcssa = phi float [ %8638, %._crit_edge1701 ], [ %10707, %__nv_exp2f.exit1321 ] + %.pn287.lcssa = phi float [ %8639, %._crit_edge1701 ], [ %10708, %__nv_exp2f.exit1321 ] + %.pn285.lcssa = phi float [ %8640, %._crit_edge1701 ], [ %10709, %__nv_exp2f.exit1321 ] + %.pn283.lcssa = phi float [ %8641, %._crit_edge1701 ], [ %10710, %__nv_exp2f.exit1321 ] + %.pn281.lcssa = phi float [ %8642, %._crit_edge1701 ], [ %10711, %__nv_exp2f.exit1321 ] + %.pn279.lcssa = phi float [ %8643, %._crit_edge1701 ], [ %10712, %__nv_exp2f.exit1321 ] + %.pn277.lcssa = phi float [ %8644, %._crit_edge1701 ], [ %10713, %__nv_exp2f.exit1321 ] + %.pn275.lcssa = phi float [ %8645, %._crit_edge1701 ], [ %10714, %__nv_exp2f.exit1321 ] + %.pn273.lcssa = phi float [ %8646, %._crit_edge1701 ], [ %10715, %__nv_exp2f.exit1321 ] + %.pn271.lcssa = phi float [ %8647, %._crit_edge1701 ], [ %10716, %__nv_exp2f.exit1321 ] + %.pn269.lcssa = phi float [ %8648, %._crit_edge1701 ], [ %10717, %__nv_exp2f.exit1321 ] + %.pn267.lcssa = phi float [ %8649, %._crit_edge1701 ], [ %10718, %__nv_exp2f.exit1321 ] + %.pn265.lcssa = phi float [ %8650, %._crit_edge1701 ], [ %10719, %__nv_exp2f.exit1321 ] + %.pn263.lcssa = phi float [ %8651, %._crit_edge1701 ], [ %10720, %__nv_exp2f.exit1321 ] + %.pn261.lcssa = phi float [ %8652, %._crit_edge1701 ], [ %10721, %__nv_exp2f.exit1321 ] + %.pn259.lcssa = phi float [ %8653, %._crit_edge1701 ], [ %10722, %__nv_exp2f.exit1321 ] + %.pn257.lcssa = phi float [ %8654, %._crit_edge1701 ], [ %10723, %__nv_exp2f.exit1321 ] + %.pn255.lcssa = phi float [ %8655, %._crit_edge1701 ], [ %10724, %__nv_exp2f.exit1321 ] + %.pn253.lcssa = phi float [ %8656, %._crit_edge1701 ], [ %10725, %__nv_exp2f.exit1321 ] + %.pn251.lcssa = phi float [ %8657, %._crit_edge1701 ], [ %10726, %__nv_exp2f.exit1321 ] + %.pn249.lcssa = phi float [ %8658, %._crit_edge1701 ], [ %10727, %__nv_exp2f.exit1321 ] + %.pn247.lcssa = phi float [ %8659, %._crit_edge1701 ], [ %10728, %__nv_exp2f.exit1321 ] + %.pn245.lcssa = phi float [ %8660, %._crit_edge1701 ], [ %10729, %__nv_exp2f.exit1321 ] + %.pn243.lcssa = phi float [ %8661, %._crit_edge1701 ], [ %10730, %__nv_exp2f.exit1321 ] + %.pn241.lcssa = phi float [ %8662, %._crit_edge1701 ], [ %10731, %__nv_exp2f.exit1321 ] + %.pn239.lcssa = phi float [ %8663, %._crit_edge1701 ], [ %10732, %__nv_exp2f.exit1321 ] + %.pn237.lcssa = phi float [ %8664, %._crit_edge1701 ], [ %10733, %__nv_exp2f.exit1321 ] + %.pn235.lcssa = phi float [ %8665, %._crit_edge1701 ], [ %10734, %__nv_exp2f.exit1321 ] + %.pn233.lcssa = phi float [ %8666, %._crit_edge1701 ], [ %10735, %__nv_exp2f.exit1321 ] + %.pn231.lcssa = phi float [ %8667, %._crit_edge1701 ], [ %10736, %__nv_exp2f.exit1321 ] + %.pn229.lcssa = phi float [ %8668, %._crit_edge1701 ], [ %10737, %__nv_exp2f.exit1321 ] + %.pn227.lcssa = phi float [ %8669, %._crit_edge1701 ], [ %10738, %__nv_exp2f.exit1321 ] + %.pn225.lcssa = phi float [ %8670, %._crit_edge1701 ], [ %10739, %__nv_exp2f.exit1321 ] + %.pn223.lcssa = phi float [ %8671, %._crit_edge1701 ], [ %10740, %__nv_exp2f.exit1321 ] + %.pn221.lcssa = phi float [ %8672, %._crit_edge1701 ], [ %10741, %__nv_exp2f.exit1321 ] + %.pn475.lcssa = phi float [ %8545, %._crit_edge1701 ], [ %9822, %__nv_exp2f.exit1321 ] + %.pn473.lcssa = phi float [ %8546, %._crit_edge1701 ], [ %9823, %__nv_exp2f.exit1321 ] + %.pn471.lcssa = phi float [ %8547, %._crit_edge1701 ], [ %9824, %__nv_exp2f.exit1321 ] + %.pn469.lcssa = phi float [ %8548, %._crit_edge1701 ], [ %9825, %__nv_exp2f.exit1321 ] + %.pn467.lcssa = phi float [ %8549, %._crit_edge1701 ], [ %9826, %__nv_exp2f.exit1321 ] + %.pn465.lcssa = phi float [ %8550, %._crit_edge1701 ], [ %9827, %__nv_exp2f.exit1321 ] + %.pn463.lcssa = phi float [ %8551, %._crit_edge1701 ], [ %9828, %__nv_exp2f.exit1321 ] + %.pn461.lcssa = phi float [ %8552, %._crit_edge1701 ], [ %9829, %__nv_exp2f.exit1321 ] + %.pn459.lcssa = phi float [ %8553, %._crit_edge1701 ], [ %9830, %__nv_exp2f.exit1321 ] + %.pn457.lcssa = phi float [ %8554, %._crit_edge1701 ], [ %9831, %__nv_exp2f.exit1321 ] + %.pn455.lcssa = phi float [ %8555, %._crit_edge1701 ], [ %9832, %__nv_exp2f.exit1321 ] + %.pn453.lcssa = phi float [ %8556, %._crit_edge1701 ], [ %9833, %__nv_exp2f.exit1321 ] + %.pn451.lcssa = phi float [ %8557, %._crit_edge1701 ], [ %9834, %__nv_exp2f.exit1321 ] + %.pn449.lcssa = phi float [ %8558, %._crit_edge1701 ], [ %9835, %__nv_exp2f.exit1321 ] + %.pn447.lcssa = phi float [ %8559, %._crit_edge1701 ], [ %9836, %__nv_exp2f.exit1321 ] + %.pn445.lcssa = phi float [ %8560, %._crit_edge1701 ], [ %9837, %__nv_exp2f.exit1321 ] + %.pn443.lcssa = phi float [ %8561, %._crit_edge1701 ], [ %9838, %__nv_exp2f.exit1321 ] + %.pn441.lcssa = phi float [ %8562, %._crit_edge1701 ], [ %9839, %__nv_exp2f.exit1321 ] + %.pn439.lcssa = phi float [ %8563, %._crit_edge1701 ], [ %9840, %__nv_exp2f.exit1321 ] + %.pn437.lcssa = phi float [ %8564, %._crit_edge1701 ], [ %9841, %__nv_exp2f.exit1321 ] + %.pn435.lcssa = phi float [ %8565, %._crit_edge1701 ], [ %9842, %__nv_exp2f.exit1321 ] + %.pn433.lcssa = phi float [ %8566, %._crit_edge1701 ], [ %9843, %__nv_exp2f.exit1321 ] + %.pn431.lcssa = phi float [ %8567, %._crit_edge1701 ], [ %9844, %__nv_exp2f.exit1321 ] + %.pn429.lcssa = phi float [ %8568, %._crit_edge1701 ], [ %9845, %__nv_exp2f.exit1321 ] + %.pn427.lcssa = phi float [ %8569, %._crit_edge1701 ], [ %9846, %__nv_exp2f.exit1321 ] + %.pn425.lcssa = phi float [ %8570, %._crit_edge1701 ], [ %9847, %__nv_exp2f.exit1321 ] + %.pn423.lcssa = phi float [ %8571, %._crit_edge1701 ], [ %9848, %__nv_exp2f.exit1321 ] + %.pn421.lcssa = phi float [ %8572, %._crit_edge1701 ], [ %9849, %__nv_exp2f.exit1321 ] + %.pn419.lcssa = phi float [ %8573, %._crit_edge1701 ], [ %9850, %__nv_exp2f.exit1321 ] + %.pn417.lcssa = phi float [ %8574, %._crit_edge1701 ], [ %9851, %__nv_exp2f.exit1321 ] + %.pn415.lcssa = phi float [ %8575, %._crit_edge1701 ], [ %9852, %__nv_exp2f.exit1321 ] + %.pn413.lcssa = phi float [ %8576, %._crit_edge1701 ], [ %9853, %__nv_exp2f.exit1321 ] + %.pn411.lcssa = phi float [ %8577, %._crit_edge1701 ], [ %9854, %__nv_exp2f.exit1321 ] + %.pn409.lcssa = phi float [ %8578, %._crit_edge1701 ], [ %9855, %__nv_exp2f.exit1321 ] + %.pn407.lcssa = phi float [ %8579, %._crit_edge1701 ], [ %9856, %__nv_exp2f.exit1321 ] + %.pn405.lcssa = phi float [ %8580, %._crit_edge1701 ], [ %9857, %__nv_exp2f.exit1321 ] + %.pn403.lcssa = phi float [ %8581, %._crit_edge1701 ], [ %9858, %__nv_exp2f.exit1321 ] + %.pn401.lcssa = phi float [ %8582, %._crit_edge1701 ], [ %9859, %__nv_exp2f.exit1321 ] + %.pn399.lcssa = phi float [ %8583, %._crit_edge1701 ], [ %9860, %__nv_exp2f.exit1321 ] + %.pn397.lcssa = phi float [ %8584, %._crit_edge1701 ], [ %9861, %__nv_exp2f.exit1321 ] + %.pn395.lcssa = phi float [ %8585, %._crit_edge1701 ], [ %9862, %__nv_exp2f.exit1321 ] + %.pn393.lcssa = phi float [ %8586, %._crit_edge1701 ], [ %9863, %__nv_exp2f.exit1321 ] + %.pn391.lcssa = phi float [ %8587, %._crit_edge1701 ], [ %9864, %__nv_exp2f.exit1321 ] + %.pn389.lcssa = phi float [ %8588, %._crit_edge1701 ], [ %9865, %__nv_exp2f.exit1321 ] + %.pn387.lcssa = phi float [ %8589, %._crit_edge1701 ], [ %9866, %__nv_exp2f.exit1321 ] + %.pn385.lcssa = phi float [ %8590, %._crit_edge1701 ], [ %9867, %__nv_exp2f.exit1321 ] + %.pn383.lcssa = phi float [ %8591, %._crit_edge1701 ], [ %9868, %__nv_exp2f.exit1321 ] + %.pn381.lcssa = phi float [ %8592, %._crit_edge1701 ], [ %9869, %__nv_exp2f.exit1321 ] + %.pn379.lcssa = phi float [ %8593, %._crit_edge1701 ], [ %9870, %__nv_exp2f.exit1321 ] + %.pn377.lcssa = phi float [ %8594, %._crit_edge1701 ], [ %9871, %__nv_exp2f.exit1321 ] + %.pn375.lcssa = phi float [ %8595, %._crit_edge1701 ], [ %9872, %__nv_exp2f.exit1321 ] + %.pn373.lcssa = phi float [ %8596, %._crit_edge1701 ], [ %9873, %__nv_exp2f.exit1321 ] + %.pn371.lcssa = phi float [ %8597, %._crit_edge1701 ], [ %9874, %__nv_exp2f.exit1321 ] + %.pn369.lcssa = phi float [ %8598, %._crit_edge1701 ], [ %9875, %__nv_exp2f.exit1321 ] + %.pn367.lcssa = phi float [ %8599, %._crit_edge1701 ], [ %9876, %__nv_exp2f.exit1321 ] + %.pn365.lcssa = phi float [ %8600, %._crit_edge1701 ], [ %9877, %__nv_exp2f.exit1321 ] + %.pn363.lcssa = phi float [ %8601, %._crit_edge1701 ], [ %9878, %__nv_exp2f.exit1321 ] + %.pn361.lcssa = phi float [ %8602, %._crit_edge1701 ], [ %9879, %__nv_exp2f.exit1321 ] + %.pn359.lcssa = phi float [ %8603, %._crit_edge1701 ], [ %9880, %__nv_exp2f.exit1321 ] + %.pn357.lcssa = phi float [ %8604, %._crit_edge1701 ], [ %9881, %__nv_exp2f.exit1321 ] + %.pn355.lcssa = phi float [ %8605, %._crit_edge1701 ], [ %9882, %__nv_exp2f.exit1321 ] + %.pn353.lcssa = phi float [ %8606, %._crit_edge1701 ], [ %9883, %__nv_exp2f.exit1321 ] + %.pn351.lcssa = phi float [ %8607, %._crit_edge1701 ], [ %9884, %__nv_exp2f.exit1321 ] + %.pn349.lcssa = phi float [ %8608, %._crit_edge1701 ], [ %9885, %__nv_exp2f.exit1321 ] + %10970 = tail call { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } asm sideeffect "// wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127\0A\09wgmma.wait_group.sync.aligned 0;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127"(float %.pn475.lcssa, float %.pn473.lcssa, float %.pn471.lcssa, float %.pn469.lcssa, float %.pn467.lcssa, float %.pn465.lcssa, float %.pn463.lcssa, float %.pn461.lcssa, float %.pn459.lcssa, float %.pn457.lcssa, float %.pn455.lcssa, float %.pn453.lcssa, float %.pn451.lcssa, float %.pn449.lcssa, float %.pn447.lcssa, float %.pn445.lcssa, float %.pn443.lcssa, float %.pn441.lcssa, float %.pn439.lcssa, float %.pn437.lcssa, float %.pn435.lcssa, float %.pn433.lcssa, float %.pn431.lcssa, float %.pn429.lcssa, float %.pn427.lcssa, float %.pn425.lcssa, float %.pn423.lcssa, float %.pn421.lcssa, float %.pn419.lcssa, float %.pn417.lcssa, float %.pn415.lcssa, float %.pn413.lcssa, float %.pn411.lcssa, float %.pn409.lcssa, float %.pn407.lcssa, float %.pn405.lcssa, float %.pn403.lcssa, float %.pn401.lcssa, float %.pn399.lcssa, float %.pn397.lcssa, float %.pn395.lcssa, float %.pn393.lcssa, float %.pn391.lcssa, float %.pn389.lcssa, float %.pn387.lcssa, float %.pn385.lcssa, float %.pn383.lcssa, float %.pn381.lcssa, float %.pn379.lcssa, float %.pn377.lcssa, float %.pn375.lcssa, float %.pn373.lcssa, float %.pn371.lcssa, float %.pn369.lcssa, float %.pn367.lcssa, float %.pn365.lcssa, float %.pn363.lcssa, float %.pn361.lcssa, float %.pn359.lcssa, float %.pn357.lcssa, float %.pn355.lcssa, float %.pn353.lcssa, float %.pn351.lcssa, float %.pn349.lcssa, float %.pn347.lcssa, float %.pn345.lcssa, float %.pn343.lcssa, float %.pn341.lcssa, float %.pn339.lcssa, float %.pn337.lcssa, float %.pn335.lcssa, float %.pn333.lcssa, float %.pn331.lcssa, float %.pn329.lcssa, float %.pn327.lcssa, float %.pn325.lcssa, float %.pn323.lcssa, float %.pn321.lcssa, float %.pn319.lcssa, float %.pn317.lcssa, float %.pn315.lcssa, float %.pn313.lcssa, float %.pn311.lcssa, float %.pn309.lcssa, float %.pn307.lcssa, float %.pn305.lcssa, float %.pn303.lcssa, float %.pn301.lcssa, float %.pn299.lcssa, float %.pn297.lcssa, float %.pn295.lcssa, float %.pn293.lcssa, float %.pn291.lcssa, float %.pn289.lcssa, float %.pn287.lcssa, float %.pn285.lcssa, float %.pn283.lcssa, float %.pn281.lcssa, float %.pn279.lcssa, float %.pn277.lcssa, float %.pn275.lcssa, float %.pn273.lcssa, float %.pn271.lcssa, float %.pn269.lcssa, float %.pn267.lcssa, float %.pn265.lcssa, float %.pn263.lcssa, float %.pn261.lcssa, float %.pn259.lcssa, float %.pn257.lcssa, float %.pn255.lcssa, float %.pn253.lcssa, float %.pn251.lcssa, float %.pn249.lcssa, float %.pn247.lcssa, float %.pn245.lcssa, float %.pn243.lcssa, float %.pn241.lcssa, float %.pn239.lcssa, float %.pn237.lcssa, float %.pn235.lcssa, float %.pn233.lcssa, float %.pn231.lcssa, float %.pn229.lcssa, float %.pn227.lcssa, float %.pn225.lcssa, float %.pn223.lcssa, float %.pn221.lcssa) #3, !dbg !332 + %10971 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 0, !dbg !332 + %10972 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 1, !dbg !332 + %10973 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 2, !dbg !332 + %10974 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 3, !dbg !332 + %10975 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 4, !dbg !332 + %10976 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 5, !dbg !332 + %10977 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 6, !dbg !332 + %10978 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 7, !dbg !332 + %10979 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 8, !dbg !332 + %10980 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 9, !dbg !332 + %10981 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 10, !dbg !332 + %10982 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 11, !dbg !332 + %10983 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 12, !dbg !332 + %10984 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 13, !dbg !332 + %10985 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 14, !dbg !332 + %10986 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 15, !dbg !332 + %10987 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 16, !dbg !332 + %10988 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 17, !dbg !332 + %10989 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 18, !dbg !332 + %10990 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 19, !dbg !332 + %10991 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 20, !dbg !332 + %10992 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 21, !dbg !332 + %10993 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 22, !dbg !332 + %10994 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 23, !dbg !332 + %10995 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 24, !dbg !332 + %10996 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 25, !dbg !332 + %10997 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 26, !dbg !332 + %10998 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 27, !dbg !332 + %10999 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 28, !dbg !332 + %11000 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 29, !dbg !332 + %11001 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 30, !dbg !332 + %11002 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 31, !dbg !332 + %11003 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 32, !dbg !332 + %11004 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 33, !dbg !332 + %11005 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 34, !dbg !332 + %11006 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 35, !dbg !332 + %11007 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 36, !dbg !332 + %11008 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 37, !dbg !332 + %11009 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 38, !dbg !332 + %11010 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 39, !dbg !332 + %11011 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 40, !dbg !332 + %11012 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 41, !dbg !332 + %11013 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 42, !dbg !332 + %11014 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 43, !dbg !332 + %11015 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 44, !dbg !332 + %11016 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 45, !dbg !332 + %11017 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 46, !dbg !332 + %11018 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 47, !dbg !332 + %11019 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 48, !dbg !332 + %11020 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 49, !dbg !332 + %11021 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 50, !dbg !332 + %11022 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 51, !dbg !332 + %11023 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 52, !dbg !332 + %11024 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 53, !dbg !332 + %11025 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 54, !dbg !332 + %11026 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 55, !dbg !332 + %11027 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 56, !dbg !332 + %11028 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 57, !dbg !332 + %11029 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 58, !dbg !332 + %11030 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 59, !dbg !332 + %11031 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 60, !dbg !332 + %11032 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 61, !dbg !332 + %11033 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 62, !dbg !332 + %11034 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 63, !dbg !332 + %11035 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 64, !dbg !332 + %11036 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 65, !dbg !332 + %11037 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 66, !dbg !332 + %11038 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 67, !dbg !332 + %11039 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 68, !dbg !332 + %11040 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 69, !dbg !332 + %11041 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 70, !dbg !332 + %11042 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 71, !dbg !332 + %11043 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 72, !dbg !332 + %11044 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 73, !dbg !332 + %11045 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 74, !dbg !332 + %11046 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 75, !dbg !332 + %11047 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 76, !dbg !332 + %11048 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 77, !dbg !332 + %11049 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 78, !dbg !332 + %11050 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 79, !dbg !332 + %11051 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 80, !dbg !332 + %11052 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 81, !dbg !332 + %11053 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 82, !dbg !332 + %11054 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 83, !dbg !332 + %11055 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 84, !dbg !332 + %11056 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 85, !dbg !332 + %11057 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 86, !dbg !332 + %11058 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 87, !dbg !332 + %11059 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 88, !dbg !332 + %11060 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 89, !dbg !332 + %11061 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 90, !dbg !332 + %11062 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 91, !dbg !332 + %11063 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 92, !dbg !332 + %11064 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 93, !dbg !332 + %11065 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 94, !dbg !332 + %11066 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 95, !dbg !332 + %11067 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 96, !dbg !332 + %11068 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 97, !dbg !332 + %11069 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 98, !dbg !332 + %11070 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 99, !dbg !332 + %11071 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 100, !dbg !332 + %11072 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 101, !dbg !332 + %11073 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 102, !dbg !332 + %11074 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 103, !dbg !332 + %11075 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 104, !dbg !332 + %11076 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 105, !dbg !332 + %11077 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 106, !dbg !332 + %11078 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 107, !dbg !332 + %11079 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 108, !dbg !332 + %11080 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 109, !dbg !332 + %11081 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 110, !dbg !332 + %11082 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 111, !dbg !332 + %11083 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 112, !dbg !332 + %11084 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 113, !dbg !332 + %11085 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 114, !dbg !332 + %11086 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 115, !dbg !332 + %11087 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 116, !dbg !332 + %11088 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 117, !dbg !332 + %11089 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 118, !dbg !332 + %11090 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 119, !dbg !332 + %11091 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 120, !dbg !332 + %11092 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 121, !dbg !332 + %11093 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 122, !dbg !332 + %11094 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 123, !dbg !332 + %11095 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 124, !dbg !332 + %11096 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 125, !dbg !332 + %11097 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 126, !dbg !332 + %11098 = extractvalue { float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float } %10970, 127, !dbg !332 + tail call void @llvm.nvvm.cp.async.wait.group(i32 0), !dbg !332 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !332 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1, !dbg !249 + %exitcond2269.not = icmp eq i64 %indvars.iv.next, 4, !dbg !249 + br i1 %exitcond2269.not, label %11099, label %5575, !dbg !249 + +11099: ; preds = %._crit_edge1874 + %11100 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4750, !dbg !369 + %11101 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4752, !dbg !369 + %11102 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4754, !dbg !369 + %11103 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4756, !dbg !369 + %11104 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4758, !dbg !369 + %11105 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4760, !dbg !369 + %11106 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4762, !dbg !369 + %11107 = getelementptr bfloat, ptr addrspace(1) %43, i64 %4764, !dbg !369 + %11108 = getelementptr bfloat, ptr addrspace(1) %11100, i64 %4768, !dbg !370 + %11109 = getelementptr bfloat, ptr addrspace(1) %11101, i64 %4768, !dbg !370 + %11110 = getelementptr bfloat, ptr addrspace(1) %11102, i64 %4768, !dbg !370 + %11111 = getelementptr bfloat, ptr addrspace(1) %11103, i64 %4768, !dbg !370 + %11112 = getelementptr bfloat, ptr addrspace(1) %11104, i64 %4768, !dbg !370 + %11113 = getelementptr bfloat, ptr addrspace(1) %11105, i64 %4768, !dbg !370 + %11114 = getelementptr bfloat, ptr addrspace(1) %11106, i64 %4768, !dbg !370 + %11115 = getelementptr bfloat, ptr addrspace(1) %11107, i64 %4768, !dbg !370 + %11116 = insertelement <2 x float> poison, float %10971, i64 0, !dbg !371 + %11117 = insertelement <2 x float> %11116, float %10972, i64 1, !dbg !371 + %11118 = fptrunc <2 x float> %11117 to <2 x bfloat>, !dbg !371 + %11119 = insertelement <2 x float> poison, float %10973, i64 0, !dbg !371 + %11120 = insertelement <2 x float> %11119, float %10974, i64 1, !dbg !371 + %11121 = fptrunc <2 x float> %11120 to <2 x bfloat>, !dbg !371 + %11122 = insertelement <2 x float> poison, float %10975, i64 0, !dbg !371 + %11123 = insertelement <2 x float> %11122, float %10976, i64 1, !dbg !371 + %11124 = fptrunc <2 x float> %11123 to <2 x bfloat>, !dbg !371 + %11125 = insertelement <2 x float> poison, float %10977, i64 0, !dbg !371 + %11126 = insertelement <2 x float> %11125, float %10978, i64 1, !dbg !371 + %11127 = fptrunc <2 x float> %11126 to <2 x bfloat>, !dbg !371 + %11128 = insertelement <2 x float> poison, float %10979, i64 0, !dbg !371 + %11129 = insertelement <2 x float> %11128, float %10980, i64 1, !dbg !371 + %11130 = fptrunc <2 x float> %11129 to <2 x bfloat>, !dbg !371 + %11131 = insertelement <2 x float> poison, float %10981, i64 0, !dbg !371 + %11132 = insertelement <2 x float> %11131, float %10982, i64 1, !dbg !371 + %11133 = fptrunc <2 x float> %11132 to <2 x bfloat>, !dbg !371 + %11134 = insertelement <2 x float> poison, float %10983, i64 0, !dbg !371 + %11135 = insertelement <2 x float> %11134, float %10984, i64 1, !dbg !371 + %11136 = fptrunc <2 x float> %11135 to <2 x bfloat>, !dbg !371 + %11137 = insertelement <2 x float> poison, float %10985, i64 0, !dbg !371 + %11138 = insertelement <2 x float> %11137, float %10986, i64 1, !dbg !371 + %11139 = fptrunc <2 x float> %11138 to <2 x bfloat>, !dbg !371 + %11140 = insertelement <2 x float> poison, float %10987, i64 0, !dbg !371 + %11141 = insertelement <2 x float> %11140, float %10988, i64 1, !dbg !371 + %11142 = fptrunc <2 x float> %11141 to <2 x bfloat>, !dbg !371 + %11143 = insertelement <2 x float> poison, float %10989, i64 0, !dbg !371 + %11144 = insertelement <2 x float> %11143, float %10990, i64 1, !dbg !371 + %11145 = fptrunc <2 x float> %11144 to <2 x bfloat>, !dbg !371 + %11146 = insertelement <2 x float> poison, float %10991, i64 0, !dbg !371 + %11147 = insertelement <2 x float> %11146, float %10992, i64 1, !dbg !371 + %11148 = fptrunc <2 x float> %11147 to <2 x bfloat>, !dbg !371 + %11149 = insertelement <2 x float> poison, float %10993, i64 0, !dbg !371 + %11150 = insertelement <2 x float> %11149, float %10994, i64 1, !dbg !371 + %11151 = fptrunc <2 x float> %11150 to <2 x bfloat>, !dbg !371 + %11152 = insertelement <2 x float> poison, float %10995, i64 0, !dbg !371 + %11153 = insertelement <2 x float> %11152, float %10996, i64 1, !dbg !371 + %11154 = fptrunc <2 x float> %11153 to <2 x bfloat>, !dbg !371 + %11155 = insertelement <2 x float> poison, float %10997, i64 0, !dbg !371 + %11156 = insertelement <2 x float> %11155, float %10998, i64 1, !dbg !371 + %11157 = fptrunc <2 x float> %11156 to <2 x bfloat>, !dbg !371 + %11158 = insertelement <2 x float> poison, float %10999, i64 0, !dbg !371 + %11159 = insertelement <2 x float> %11158, float %11000, i64 1, !dbg !371 + %11160 = fptrunc <2 x float> %11159 to <2 x bfloat>, !dbg !371 + %11161 = insertelement <2 x float> poison, float %11001, i64 0, !dbg !371 + %11162 = insertelement <2 x float> %11161, float %11002, i64 1, !dbg !371 + %11163 = fptrunc <2 x float> %11162 to <2 x bfloat>, !dbg !371 + %11164 = insertelement <2 x float> poison, float %11003, i64 0, !dbg !371 + %11165 = insertelement <2 x float> %11164, float %11004, i64 1, !dbg !371 + %11166 = fptrunc <2 x float> %11165 to <2 x bfloat>, !dbg !371 + %11167 = insertelement <2 x float> poison, float %11005, i64 0, !dbg !371 + %11168 = insertelement <2 x float> %11167, float %11006, i64 1, !dbg !371 + %11169 = fptrunc <2 x float> %11168 to <2 x bfloat>, !dbg !371 + %11170 = insertelement <2 x float> poison, float %11007, i64 0, !dbg !371 + %11171 = insertelement <2 x float> %11170, float %11008, i64 1, !dbg !371 + %11172 = fptrunc <2 x float> %11171 to <2 x bfloat>, !dbg !371 + %11173 = insertelement <2 x float> poison, float %11009, i64 0, !dbg !371 + %11174 = insertelement <2 x float> %11173, float %11010, i64 1, !dbg !371 + %11175 = fptrunc <2 x float> %11174 to <2 x bfloat>, !dbg !371 + %11176 = insertelement <2 x float> poison, float %11011, i64 0, !dbg !371 + %11177 = insertelement <2 x float> %11176, float %11012, i64 1, !dbg !371 + %11178 = fptrunc <2 x float> %11177 to <2 x bfloat>, !dbg !371 + %11179 = insertelement <2 x float> poison, float %11013, i64 0, !dbg !371 + %11180 = insertelement <2 x float> %11179, float %11014, i64 1, !dbg !371 + %11181 = fptrunc <2 x float> %11180 to <2 x bfloat>, !dbg !371 + %11182 = insertelement <2 x float> poison, float %11015, i64 0, !dbg !371 + %11183 = insertelement <2 x float> %11182, float %11016, i64 1, !dbg !371 + %11184 = fptrunc <2 x float> %11183 to <2 x bfloat>, !dbg !371 + %11185 = insertelement <2 x float> poison, float %11017, i64 0, !dbg !371 + %11186 = insertelement <2 x float> %11185, float %11018, i64 1, !dbg !371 + %11187 = fptrunc <2 x float> %11186 to <2 x bfloat>, !dbg !371 + %11188 = insertelement <2 x float> poison, float %11019, i64 0, !dbg !371 + %11189 = insertelement <2 x float> %11188, float %11020, i64 1, !dbg !371 + %11190 = fptrunc <2 x float> %11189 to <2 x bfloat>, !dbg !371 + %11191 = insertelement <2 x float> poison, float %11021, i64 0, !dbg !371 + %11192 = insertelement <2 x float> %11191, float %11022, i64 1, !dbg !371 + %11193 = fptrunc <2 x float> %11192 to <2 x bfloat>, !dbg !371 + %11194 = insertelement <2 x float> poison, float %11023, i64 0, !dbg !371 + %11195 = insertelement <2 x float> %11194, float %11024, i64 1, !dbg !371 + %11196 = fptrunc <2 x float> %11195 to <2 x bfloat>, !dbg !371 + %11197 = insertelement <2 x float> poison, float %11025, i64 0, !dbg !371 + %11198 = insertelement <2 x float> %11197, float %11026, i64 1, !dbg !371 + %11199 = fptrunc <2 x float> %11198 to <2 x bfloat>, !dbg !371 + %11200 = insertelement <2 x float> poison, float %11027, i64 0, !dbg !371 + %11201 = insertelement <2 x float> %11200, float %11028, i64 1, !dbg !371 + %11202 = fptrunc <2 x float> %11201 to <2 x bfloat>, !dbg !371 + %11203 = insertelement <2 x float> poison, float %11029, i64 0, !dbg !371 + %11204 = insertelement <2 x float> %11203, float %11030, i64 1, !dbg !371 + %11205 = fptrunc <2 x float> %11204 to <2 x bfloat>, !dbg !371 + %11206 = insertelement <2 x float> poison, float %11031, i64 0, !dbg !371 + %11207 = insertelement <2 x float> %11206, float %11032, i64 1, !dbg !371 + %11208 = fptrunc <2 x float> %11207 to <2 x bfloat>, !dbg !371 + %11209 = insertelement <2 x float> poison, float %11033, i64 0, !dbg !371 + %11210 = insertelement <2 x float> %11209, float %11034, i64 1, !dbg !371 + %11211 = fptrunc <2 x float> %11210 to <2 x bfloat>, !dbg !371 + %11212 = shl nuw nsw i32 %4987, 13, !dbg !371 + %11213 = shl nuw nsw i32 %44, 5, !dbg !371 + %11214 = and i32 %11213, 7264, !dbg !371 + %11215 = and i32 %44, 24, !dbg !371 + %11216 = shl nuw nsw i32 %11215, 4, !dbg !371 + %11217 = shl nuw nsw i32 %44, 2, !dbg !371 + %11218 = and i32 %11217, 16, !dbg !371 + %11219 = or disjoint i32 %11212, %11218, !dbg !371 + %11220 = or disjoint i32 %11214, %11216, !dbg !371 + %11221 = or disjoint i32 %11219, %11220, !dbg !371 + %11222 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11221, !dbg !371 + %11223 = bitcast <2 x bfloat> %11118 to i32, !dbg !371 + %11224 = bitcast <2 x bfloat> %11124 to i32, !dbg !371 + %11225 = bitcast <2 x bfloat> %11130 to i32, !dbg !371 + %11226 = bitcast <2 x bfloat> %11136 to i32, !dbg !371 + %11227 = insertelement <4 x i32> poison, i32 %11223, i64 0, !dbg !371 + %11228 = insertelement <4 x i32> %11227, i32 %11224, i64 1, !dbg !371 + %11229 = insertelement <4 x i32> %11228, i32 %11225, i64 2, !dbg !371 + %11230 = insertelement <4 x i32> %11229, i32 %11226, i64 3, !dbg !371 + store <4 x i32> %11230, ptr addrspace(3) %11222, align 16, !dbg !371 + %11231 = getelementptr inbounds nuw i8, ptr addrspace(3) %11222, i32 512, !dbg !371 + %11232 = bitcast <2 x bfloat> %11121 to i32, !dbg !371 + %11233 = bitcast <2 x bfloat> %11127 to i32, !dbg !371 + %11234 = bitcast <2 x bfloat> %11133 to i32, !dbg !371 + %11235 = bitcast <2 x bfloat> %11139 to i32, !dbg !371 + %11236 = insertelement <4 x i32> poison, i32 %11232, i64 0, !dbg !371 + %11237 = insertelement <4 x i32> %11236, i32 %11233, i64 1, !dbg !371 + %11238 = insertelement <4 x i32> %11237, i32 %11234, i64 2, !dbg !371 + %11239 = insertelement <4 x i32> %11238, i32 %11235, i64 3, !dbg !371 + store <4 x i32> %11239, ptr addrspace(3) %11231, align 16, !dbg !371 + %11240 = xor i32 %11221, 32, !dbg !371 + %11241 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11240, !dbg !371 + %11242 = bitcast <2 x bfloat> %11142 to i32, !dbg !371 + %11243 = bitcast <2 x bfloat> %11148 to i32, !dbg !371 + %11244 = bitcast <2 x bfloat> %11154 to i32, !dbg !371 + %11245 = bitcast <2 x bfloat> %11160 to i32, !dbg !371 + %11246 = insertelement <4 x i32> poison, i32 %11242, i64 0, !dbg !371 + %11247 = insertelement <4 x i32> %11246, i32 %11243, i64 1, !dbg !371 + %11248 = insertelement <4 x i32> %11247, i32 %11244, i64 2, !dbg !371 + %11249 = insertelement <4 x i32> %11248, i32 %11245, i64 3, !dbg !371 + store <4 x i32> %11249, ptr addrspace(3) %11241, align 16, !dbg !371 + %11250 = getelementptr inbounds nuw i8, ptr addrspace(3) %11241, i32 512, !dbg !371 + %11251 = bitcast <2 x bfloat> %11145 to i32, !dbg !371 + %11252 = bitcast <2 x bfloat> %11151 to i32, !dbg !371 + %11253 = bitcast <2 x bfloat> %11157 to i32, !dbg !371 + %11254 = bitcast <2 x bfloat> %11163 to i32, !dbg !371 + %11255 = insertelement <4 x i32> poison, i32 %11251, i64 0, !dbg !371 + %11256 = insertelement <4 x i32> %11255, i32 %11252, i64 1, !dbg !371 + %11257 = insertelement <4 x i32> %11256, i32 %11253, i64 2, !dbg !371 + %11258 = insertelement <4 x i32> %11257, i32 %11254, i64 3, !dbg !371 + store <4 x i32> %11258, ptr addrspace(3) %11250, align 16, !dbg !371 + %11259 = xor i32 %11221, 64, !dbg !371 + %11260 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11259, !dbg !371 + %11261 = bitcast <2 x bfloat> %11166 to i32, !dbg !371 + %11262 = bitcast <2 x bfloat> %11172 to i32, !dbg !371 + %11263 = bitcast <2 x bfloat> %11178 to i32, !dbg !371 + %11264 = bitcast <2 x bfloat> %11184 to i32, !dbg !371 + %11265 = insertelement <4 x i32> poison, i32 %11261, i64 0, !dbg !371 + %11266 = insertelement <4 x i32> %11265, i32 %11262, i64 1, !dbg !371 + %11267 = insertelement <4 x i32> %11266, i32 %11263, i64 2, !dbg !371 + %11268 = insertelement <4 x i32> %11267, i32 %11264, i64 3, !dbg !371 + store <4 x i32> %11268, ptr addrspace(3) %11260, align 16, !dbg !371 + %11269 = getelementptr inbounds nuw i8, ptr addrspace(3) %11260, i32 512, !dbg !371 + %11270 = bitcast <2 x bfloat> %11169 to i32, !dbg !371 + %11271 = bitcast <2 x bfloat> %11175 to i32, !dbg !371 + %11272 = bitcast <2 x bfloat> %11181 to i32, !dbg !371 + %11273 = bitcast <2 x bfloat> %11187 to i32, !dbg !371 + %11274 = insertelement <4 x i32> poison, i32 %11270, i64 0, !dbg !371 + %11275 = insertelement <4 x i32> %11274, i32 %11271, i64 1, !dbg !371 + %11276 = insertelement <4 x i32> %11275, i32 %11272, i64 2, !dbg !371 + %11277 = insertelement <4 x i32> %11276, i32 %11273, i64 3, !dbg !371 + store <4 x i32> %11277, ptr addrspace(3) %11269, align 16, !dbg !371 + %11278 = xor i32 %11221, 96, !dbg !371 + %11279 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11278, !dbg !371 + %11280 = bitcast <2 x bfloat> %11190 to i32, !dbg !371 + %11281 = bitcast <2 x bfloat> %11196 to i32, !dbg !371 + %11282 = bitcast <2 x bfloat> %11202 to i32, !dbg !371 + %11283 = bitcast <2 x bfloat> %11208 to i32, !dbg !371 + %11284 = insertelement <4 x i32> poison, i32 %11280, i64 0, !dbg !371 + %11285 = insertelement <4 x i32> %11284, i32 %11281, i64 1, !dbg !371 + %11286 = insertelement <4 x i32> %11285, i32 %11282, i64 2, !dbg !371 + %11287 = insertelement <4 x i32> %11286, i32 %11283, i64 3, !dbg !371 + store <4 x i32> %11287, ptr addrspace(3) %11279, align 16, !dbg !371 + %11288 = getelementptr inbounds nuw i8, ptr addrspace(3) %11279, i32 512, !dbg !371 + %11289 = bitcast <2 x bfloat> %11193 to i32, !dbg !371 + %11290 = bitcast <2 x bfloat> %11199 to i32, !dbg !371 + %11291 = bitcast <2 x bfloat> %11205 to i32, !dbg !371 + %11292 = bitcast <2 x bfloat> %11211 to i32, !dbg !371 + %11293 = insertelement <4 x i32> poison, i32 %11289, i64 0, !dbg !371 + %11294 = insertelement <4 x i32> %11293, i32 %11290, i64 1, !dbg !371 + %11295 = insertelement <4 x i32> %11294, i32 %11291, i64 2, !dbg !371 + %11296 = insertelement <4 x i32> %11295, i32 %11292, i64 3, !dbg !371 + store <4 x i32> %11296, ptr addrspace(3) %11288, align 16, !dbg !371 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !371 + %11297 = shl nuw nsw i32 %11215, 10, !dbg !371 + %11298 = shl nuw nsw i32 %4987, 5, !dbg !371 + %11299 = and i32 %11217, 1008, !dbg !371 + %11300 = or disjoint i32 %11297, %11298, !dbg !371 + %11301 = xor i32 %11300, %11299, !dbg !371 + %11302 = getelementptr inbounds nuw i8, ptr addrspace(3) @global_smem, i32 %11301, !dbg !371 + %11303 = ptrtoint ptr addrspace(3) %11302 to i32, !dbg !371 + %11304 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11303) #3, !dbg !371 + %11305 = extractvalue { i32, i32, i32, i32 } %11304, 0, !dbg !371 + %11306 = extractvalue { i32, i32, i32, i32 } %11304, 1, !dbg !371 + %11307 = extractvalue { i32, i32, i32, i32 } %11304, 2, !dbg !371 + %11308 = extractvalue { i32, i32, i32, i32 } %11304, 3, !dbg !371 + %11309 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 1024, !dbg !371 + %11310 = ptrtoint ptr addrspace(3) %11309 to i32, !dbg !371 + %11311 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11310) #3, !dbg !371 + %11312 = extractvalue { i32, i32, i32, i32 } %11311, 0, !dbg !371 + %11313 = extractvalue { i32, i32, i32, i32 } %11311, 1, !dbg !371 + %11314 = extractvalue { i32, i32, i32, i32 } %11311, 2, !dbg !371 + %11315 = extractvalue { i32, i32, i32, i32 } %11311, 3, !dbg !371 + %11316 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 2048, !dbg !371 + %11317 = ptrtoint ptr addrspace(3) %11316 to i32, !dbg !371 + %11318 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11317) #3, !dbg !371 + %11319 = extractvalue { i32, i32, i32, i32 } %11318, 0, !dbg !371 + %11320 = extractvalue { i32, i32, i32, i32 } %11318, 1, !dbg !371 + %11321 = extractvalue { i32, i32, i32, i32 } %11318, 2, !dbg !371 + %11322 = extractvalue { i32, i32, i32, i32 } %11318, 3, !dbg !371 + %11323 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 3072, !dbg !371 + %11324 = ptrtoint ptr addrspace(3) %11323 to i32, !dbg !371 + %11325 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11324) #3, !dbg !371 + %11326 = extractvalue { i32, i32, i32, i32 } %11325, 0, !dbg !371 + %11327 = extractvalue { i32, i32, i32, i32 } %11325, 1, !dbg !371 + %11328 = extractvalue { i32, i32, i32, i32 } %11325, 2, !dbg !371 + %11329 = extractvalue { i32, i32, i32, i32 } %11325, 3, !dbg !371 + %11330 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 4096, !dbg !371 + %11331 = ptrtoint ptr addrspace(3) %11330 to i32, !dbg !371 + %11332 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11331) #3, !dbg !371 + %11333 = extractvalue { i32, i32, i32, i32 } %11332, 0, !dbg !371 + %11334 = extractvalue { i32, i32, i32, i32 } %11332, 1, !dbg !371 + %11335 = extractvalue { i32, i32, i32, i32 } %11332, 2, !dbg !371 + %11336 = extractvalue { i32, i32, i32, i32 } %11332, 3, !dbg !371 + %11337 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 5120, !dbg !371 + %11338 = ptrtoint ptr addrspace(3) %11337 to i32, !dbg !371 + %11339 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11338) #3, !dbg !371 + %11340 = extractvalue { i32, i32, i32, i32 } %11339, 0, !dbg !371 + %11341 = extractvalue { i32, i32, i32, i32 } %11339, 1, !dbg !371 + %11342 = extractvalue { i32, i32, i32, i32 } %11339, 2, !dbg !371 + %11343 = extractvalue { i32, i32, i32, i32 } %11339, 3, !dbg !371 + %11344 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 6144, !dbg !371 + %11345 = ptrtoint ptr addrspace(3) %11344 to i32, !dbg !371 + %11346 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11345) #3, !dbg !371 + %11347 = extractvalue { i32, i32, i32, i32 } %11346, 0, !dbg !371 + %11348 = extractvalue { i32, i32, i32, i32 } %11346, 1, !dbg !371 + %11349 = extractvalue { i32, i32, i32, i32 } %11346, 2, !dbg !371 + %11350 = extractvalue { i32, i32, i32, i32 } %11346, 3, !dbg !371 + %11351 = getelementptr inbounds nuw i8, ptr addrspace(3) %11302, i32 7168, !dbg !371 + %11352 = ptrtoint ptr addrspace(3) %11351 to i32, !dbg !371 + %11353 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11352) #3, !dbg !371 + %11354 = extractvalue { i32, i32, i32, i32 } %11353, 0, !dbg !371 + %11355 = extractvalue { i32, i32, i32, i32 } %11353, 1, !dbg !371 + %11356 = extractvalue { i32, i32, i32, i32 } %11353, 2, !dbg !371 + %11357 = extractvalue { i32, i32, i32, i32 } %11353, 3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11305, i32 %11306, i32 %11307, i32 %11308, ptr addrspace(1) %11108, i1 %4777) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11312, i32 %11313, i32 %11314, i32 %11315, ptr addrspace(1) %11109, i1 %4778) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11319, i32 %11320, i32 %11321, i32 %11322, ptr addrspace(1) %11110, i1 %4779) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11326, i32 %11327, i32 %11328, i32 %11329, ptr addrspace(1) %11111, i1 %4780) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11333, i32 %11334, i32 %11335, i32 %11336, ptr addrspace(1) %11112, i1 %4781) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11340, i32 %11341, i32 %11342, i32 %11343, ptr addrspace(1) %11113, i1 %4782) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11347, i32 %11348, i32 %11349, i32 %11350, ptr addrspace(1) %11114, i1 %4783) #3, !dbg !371 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11354, i32 %11355, i32 %11356, i32 %11357, ptr addrspace(1) %11115, i1 %4784) #3, !dbg !371 + %11358 = insertelement <2 x float> poison, float %11035, i64 0, !dbg !372 + %11359 = insertelement <2 x float> %11358, float %11036, i64 1, !dbg !372 + %11360 = fmul <2 x float> %11359, splat (float 0x3FB6A09E60000000), !dbg !372 + %11361 = insertelement <2 x float> poison, float %11037, i64 0, !dbg !372 + %11362 = insertelement <2 x float> %11361, float %11038, i64 1, !dbg !372 + %11363 = fmul <2 x float> %11362, splat (float 0x3FB6A09E60000000), !dbg !372 + %11364 = insertelement <2 x float> poison, float %11039, i64 0, !dbg !372 + %11365 = insertelement <2 x float> %11364, float %11040, i64 1, !dbg !372 + %11366 = fmul <2 x float> %11365, splat (float 0x3FB6A09E60000000), !dbg !372 + %11367 = insertelement <2 x float> poison, float %11041, i64 0, !dbg !372 + %11368 = insertelement <2 x float> %11367, float %11042, i64 1, !dbg !372 + %11369 = fmul <2 x float> %11368, splat (float 0x3FB6A09E60000000), !dbg !372 + %11370 = insertelement <2 x float> poison, float %11043, i64 0, !dbg !372 + %11371 = insertelement <2 x float> %11370, float %11044, i64 1, !dbg !372 + %11372 = fmul <2 x float> %11371, splat (float 0x3FB6A09E60000000), !dbg !372 + %11373 = insertelement <2 x float> poison, float %11045, i64 0, !dbg !372 + %11374 = insertelement <2 x float> %11373, float %11046, i64 1, !dbg !372 + %11375 = fmul <2 x float> %11374, splat (float 0x3FB6A09E60000000), !dbg !372 + %11376 = insertelement <2 x float> poison, float %11047, i64 0, !dbg !372 + %11377 = insertelement <2 x float> %11376, float %11048, i64 1, !dbg !372 + %11378 = fmul <2 x float> %11377, splat (float 0x3FB6A09E60000000), !dbg !372 + %11379 = insertelement <2 x float> poison, float %11049, i64 0, !dbg !372 + %11380 = insertelement <2 x float> %11379, float %11050, i64 1, !dbg !372 + %11381 = fmul <2 x float> %11380, splat (float 0x3FB6A09E60000000), !dbg !372 + %11382 = insertelement <2 x float> poison, float %11051, i64 0, !dbg !372 + %11383 = insertelement <2 x float> %11382, float %11052, i64 1, !dbg !372 + %11384 = fmul <2 x float> %11383, splat (float 0x3FB6A09E60000000), !dbg !372 + %11385 = insertelement <2 x float> poison, float %11053, i64 0, !dbg !372 + %11386 = insertelement <2 x float> %11385, float %11054, i64 1, !dbg !372 + %11387 = fmul <2 x float> %11386, splat (float 0x3FB6A09E60000000), !dbg !372 + %11388 = insertelement <2 x float> poison, float %11055, i64 0, !dbg !372 + %11389 = insertelement <2 x float> %11388, float %11056, i64 1, !dbg !372 + %11390 = fmul <2 x float> %11389, splat (float 0x3FB6A09E60000000), !dbg !372 + %11391 = insertelement <2 x float> poison, float %11057, i64 0, !dbg !372 + %11392 = insertelement <2 x float> %11391, float %11058, i64 1, !dbg !372 + %11393 = fmul <2 x float> %11392, splat (float 0x3FB6A09E60000000), !dbg !372 + %11394 = insertelement <2 x float> poison, float %11059, i64 0, !dbg !372 + %11395 = insertelement <2 x float> %11394, float %11060, i64 1, !dbg !372 + %11396 = fmul <2 x float> %11395, splat (float 0x3FB6A09E60000000), !dbg !372 + %11397 = insertelement <2 x float> poison, float %11061, i64 0, !dbg !372 + %11398 = insertelement <2 x float> %11397, float %11062, i64 1, !dbg !372 + %11399 = fmul <2 x float> %11398, splat (float 0x3FB6A09E60000000), !dbg !372 + %11400 = insertelement <2 x float> poison, float %11063, i64 0, !dbg !372 + %11401 = insertelement <2 x float> %11400, float %11064, i64 1, !dbg !372 + %11402 = fmul <2 x float> %11401, splat (float 0x3FB6A09E60000000), !dbg !372 + %11403 = insertelement <2 x float> poison, float %11065, i64 0, !dbg !372 + %11404 = insertelement <2 x float> %11403, float %11066, i64 1, !dbg !372 + %11405 = fmul <2 x float> %11404, splat (float 0x3FB6A09E60000000), !dbg !372 + %11406 = insertelement <2 x float> poison, float %11067, i64 0, !dbg !372 + %11407 = insertelement <2 x float> %11406, float %11068, i64 1, !dbg !372 + %11408 = fmul <2 x float> %11407, splat (float 0x3FB6A09E60000000), !dbg !372 + %11409 = insertelement <2 x float> poison, float %11069, i64 0, !dbg !372 + %11410 = insertelement <2 x float> %11409, float %11070, i64 1, !dbg !372 + %11411 = fmul <2 x float> %11410, splat (float 0x3FB6A09E60000000), !dbg !372 + %11412 = insertelement <2 x float> poison, float %11071, i64 0, !dbg !372 + %11413 = insertelement <2 x float> %11412, float %11072, i64 1, !dbg !372 + %11414 = fmul <2 x float> %11413, splat (float 0x3FB6A09E60000000), !dbg !372 + %11415 = insertelement <2 x float> poison, float %11073, i64 0, !dbg !372 + %11416 = insertelement <2 x float> %11415, float %11074, i64 1, !dbg !372 + %11417 = fmul <2 x float> %11416, splat (float 0x3FB6A09E60000000), !dbg !372 + %11418 = insertelement <2 x float> poison, float %11075, i64 0, !dbg !372 + %11419 = insertelement <2 x float> %11418, float %11076, i64 1, !dbg !372 + %11420 = fmul <2 x float> %11419, splat (float 0x3FB6A09E60000000), !dbg !372 + %11421 = insertelement <2 x float> poison, float %11077, i64 0, !dbg !372 + %11422 = insertelement <2 x float> %11421, float %11078, i64 1, !dbg !372 + %11423 = fmul <2 x float> %11422, splat (float 0x3FB6A09E60000000), !dbg !372 + %11424 = insertelement <2 x float> poison, float %11079, i64 0, !dbg !372 + %11425 = insertelement <2 x float> %11424, float %11080, i64 1, !dbg !372 + %11426 = fmul <2 x float> %11425, splat (float 0x3FB6A09E60000000), !dbg !372 + %11427 = insertelement <2 x float> poison, float %11081, i64 0, !dbg !372 + %11428 = insertelement <2 x float> %11427, float %11082, i64 1, !dbg !372 + %11429 = fmul <2 x float> %11428, splat (float 0x3FB6A09E60000000), !dbg !372 + %11430 = insertelement <2 x float> poison, float %11083, i64 0, !dbg !372 + %11431 = insertelement <2 x float> %11430, float %11084, i64 1, !dbg !372 + %11432 = fmul <2 x float> %11431, splat (float 0x3FB6A09E60000000), !dbg !372 + %11433 = insertelement <2 x float> poison, float %11085, i64 0, !dbg !372 + %11434 = insertelement <2 x float> %11433, float %11086, i64 1, !dbg !372 + %11435 = fmul <2 x float> %11434, splat (float 0x3FB6A09E60000000), !dbg !372 + %11436 = insertelement <2 x float> poison, float %11087, i64 0, !dbg !372 + %11437 = insertelement <2 x float> %11436, float %11088, i64 1, !dbg !372 + %11438 = fmul <2 x float> %11437, splat (float 0x3FB6A09E60000000), !dbg !372 + %11439 = insertelement <2 x float> poison, float %11089, i64 0, !dbg !372 + %11440 = insertelement <2 x float> %11439, float %11090, i64 1, !dbg !372 + %11441 = fmul <2 x float> %11440, splat (float 0x3FB6A09E60000000), !dbg !372 + %11442 = insertelement <2 x float> poison, float %11091, i64 0, !dbg !372 + %11443 = insertelement <2 x float> %11442, float %11092, i64 1, !dbg !372 + %11444 = fmul <2 x float> %11443, splat (float 0x3FB6A09E60000000), !dbg !372 + %11445 = insertelement <2 x float> poison, float %11093, i64 0, !dbg !372 + %11446 = insertelement <2 x float> %11445, float %11094, i64 1, !dbg !372 + %11447 = fmul <2 x float> %11446, splat (float 0x3FB6A09E60000000), !dbg !372 + %11448 = insertelement <2 x float> poison, float %11095, i64 0, !dbg !372 + %11449 = insertelement <2 x float> %11448, float %11096, i64 1, !dbg !372 + %11450 = fmul <2 x float> %11449, splat (float 0x3FB6A09E60000000), !dbg !372 + %11451 = insertelement <2 x float> poison, float %11097, i64 0, !dbg !372 + %11452 = insertelement <2 x float> %11451, float %11098, i64 1, !dbg !372 + %11453 = fmul <2 x float> %11452, splat (float 0x3FB6A09E60000000), !dbg !372 + %11454 = or disjoint i32 %4767, %35, !dbg !373 + %11455 = add i32 %4742, %11454, !dbg !374 + %11456 = add i32 %4743, %11454, !dbg !374 + %11457 = add i32 %4744, %11454, !dbg !374 + %11458 = add i32 %4745, %11454, !dbg !374 + %11459 = add i32 %4746, %11454, !dbg !374 + %11460 = add i32 %4747, %11454, !dbg !374 + %11461 = add i32 %4748, %11454, !dbg !374 + %11462 = add i32 %4749, %11454, !dbg !374 + %11463 = sext i32 %11455 to i64, !dbg !375 + %11464 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11463, !dbg !375 + %11465 = sext i32 %11456 to i64, !dbg !375 + %11466 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11465, !dbg !375 + %11467 = sext i32 %11457 to i64, !dbg !375 + %11468 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11467, !dbg !375 + %11469 = sext i32 %11458 to i64, !dbg !375 + %11470 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11469, !dbg !375 + %11471 = sext i32 %11459 to i64, !dbg !375 + %11472 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11471, !dbg !375 + %11473 = sext i32 %11460 to i64, !dbg !375 + %11474 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11473, !dbg !375 + %11475 = sext i32 %11461 to i64, !dbg !375 + %11476 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11475, !dbg !375 + %11477 = sext i32 %11462 to i64, !dbg !375 + %11478 = getelementptr bfloat, ptr addrspace(1) %16, i64 %11477, !dbg !375 + %11479 = fptrunc <2 x float> %11360 to <2 x bfloat>, !dbg !376 + %11480 = fptrunc <2 x float> %11363 to <2 x bfloat>, !dbg !376 + %11481 = fptrunc <2 x float> %11366 to <2 x bfloat>, !dbg !376 + %11482 = fptrunc <2 x float> %11369 to <2 x bfloat>, !dbg !376 + %11483 = fptrunc <2 x float> %11372 to <2 x bfloat>, !dbg !376 + %11484 = fptrunc <2 x float> %11375 to <2 x bfloat>, !dbg !376 + %11485 = fptrunc <2 x float> %11378 to <2 x bfloat>, !dbg !376 + %11486 = fptrunc <2 x float> %11381 to <2 x bfloat>, !dbg !376 + %11487 = fptrunc <2 x float> %11384 to <2 x bfloat>, !dbg !376 + %11488 = fptrunc <2 x float> %11387 to <2 x bfloat>, !dbg !376 + %11489 = fptrunc <2 x float> %11390 to <2 x bfloat>, !dbg !376 + %11490 = fptrunc <2 x float> %11393 to <2 x bfloat>, !dbg !376 + %11491 = fptrunc <2 x float> %11396 to <2 x bfloat>, !dbg !376 + %11492 = fptrunc <2 x float> %11399 to <2 x bfloat>, !dbg !376 + %11493 = fptrunc <2 x float> %11402 to <2 x bfloat>, !dbg !376 + %11494 = fptrunc <2 x float> %11405 to <2 x bfloat>, !dbg !376 + %11495 = fptrunc <2 x float> %11408 to <2 x bfloat>, !dbg !376 + %11496 = fptrunc <2 x float> %11411 to <2 x bfloat>, !dbg !376 + %11497 = fptrunc <2 x float> %11414 to <2 x bfloat>, !dbg !376 + %11498 = fptrunc <2 x float> %11417 to <2 x bfloat>, !dbg !376 + %11499 = fptrunc <2 x float> %11420 to <2 x bfloat>, !dbg !376 + %11500 = fptrunc <2 x float> %11423 to <2 x bfloat>, !dbg !376 + %11501 = fptrunc <2 x float> %11426 to <2 x bfloat>, !dbg !376 + %11502 = fptrunc <2 x float> %11429 to <2 x bfloat>, !dbg !376 + %11503 = fptrunc <2 x float> %11432 to <2 x bfloat>, !dbg !376 + %11504 = fptrunc <2 x float> %11435 to <2 x bfloat>, !dbg !376 + %11505 = fptrunc <2 x float> %11438 to <2 x bfloat>, !dbg !376 + %11506 = fptrunc <2 x float> %11441 to <2 x bfloat>, !dbg !376 + %11507 = fptrunc <2 x float> %11444 to <2 x bfloat>, !dbg !376 + %11508 = fptrunc <2 x float> %11447 to <2 x bfloat>, !dbg !376 + %11509 = fptrunc <2 x float> %11450 to <2 x bfloat>, !dbg !376 + %11510 = fptrunc <2 x float> %11453 to <2 x bfloat>, !dbg !376 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !376 + %11511 = bitcast <2 x bfloat> %11479 to i32, !dbg !376 + %11512 = bitcast <2 x bfloat> %11481 to i32, !dbg !376 + %11513 = bitcast <2 x bfloat> %11483 to i32, !dbg !376 + %11514 = bitcast <2 x bfloat> %11485 to i32, !dbg !376 + %11515 = insertelement <4 x i32> poison, i32 %11511, i64 0, !dbg !376 + %11516 = insertelement <4 x i32> %11515, i32 %11512, i64 1, !dbg !376 + %11517 = insertelement <4 x i32> %11516, i32 %11513, i64 2, !dbg !376 + %11518 = insertelement <4 x i32> %11517, i32 %11514, i64 3, !dbg !376 + store <4 x i32> %11518, ptr addrspace(3) %11222, align 16, !dbg !376 + %11519 = bitcast <2 x bfloat> %11480 to i32, !dbg !376 + %11520 = bitcast <2 x bfloat> %11482 to i32, !dbg !376 + %11521 = bitcast <2 x bfloat> %11484 to i32, !dbg !376 + %11522 = bitcast <2 x bfloat> %11486 to i32, !dbg !376 + %11523 = insertelement <4 x i32> poison, i32 %11519, i64 0, !dbg !376 + %11524 = insertelement <4 x i32> %11523, i32 %11520, i64 1, !dbg !376 + %11525 = insertelement <4 x i32> %11524, i32 %11521, i64 2, !dbg !376 + %11526 = insertelement <4 x i32> %11525, i32 %11522, i64 3, !dbg !376 + store <4 x i32> %11526, ptr addrspace(3) %11231, align 16, !dbg !376 + %11527 = bitcast <2 x bfloat> %11487 to i32, !dbg !376 + %11528 = bitcast <2 x bfloat> %11489 to i32, !dbg !376 + %11529 = bitcast <2 x bfloat> %11491 to i32, !dbg !376 + %11530 = bitcast <2 x bfloat> %11493 to i32, !dbg !376 + %11531 = insertelement <4 x i32> poison, i32 %11527, i64 0, !dbg !376 + %11532 = insertelement <4 x i32> %11531, i32 %11528, i64 1, !dbg !376 + %11533 = insertelement <4 x i32> %11532, i32 %11529, i64 2, !dbg !376 + %11534 = insertelement <4 x i32> %11533, i32 %11530, i64 3, !dbg !376 + store <4 x i32> %11534, ptr addrspace(3) %11241, align 16, !dbg !376 + %11535 = bitcast <2 x bfloat> %11488 to i32, !dbg !376 + %11536 = bitcast <2 x bfloat> %11490 to i32, !dbg !376 + %11537 = bitcast <2 x bfloat> %11492 to i32, !dbg !376 + %11538 = bitcast <2 x bfloat> %11494 to i32, !dbg !376 + %11539 = insertelement <4 x i32> poison, i32 %11535, i64 0, !dbg !376 + %11540 = insertelement <4 x i32> %11539, i32 %11536, i64 1, !dbg !376 + %11541 = insertelement <4 x i32> %11540, i32 %11537, i64 2, !dbg !376 + %11542 = insertelement <4 x i32> %11541, i32 %11538, i64 3, !dbg !376 + store <4 x i32> %11542, ptr addrspace(3) %11250, align 16, !dbg !376 + %11543 = bitcast <2 x bfloat> %11495 to i32, !dbg !376 + %11544 = bitcast <2 x bfloat> %11497 to i32, !dbg !376 + %11545 = bitcast <2 x bfloat> %11499 to i32, !dbg !376 + %11546 = bitcast <2 x bfloat> %11501 to i32, !dbg !376 + %11547 = insertelement <4 x i32> poison, i32 %11543, i64 0, !dbg !376 + %11548 = insertelement <4 x i32> %11547, i32 %11544, i64 1, !dbg !376 + %11549 = insertelement <4 x i32> %11548, i32 %11545, i64 2, !dbg !376 + %11550 = insertelement <4 x i32> %11549, i32 %11546, i64 3, !dbg !376 + store <4 x i32> %11550, ptr addrspace(3) %11260, align 16, !dbg !376 + %11551 = bitcast <2 x bfloat> %11496 to i32, !dbg !376 + %11552 = bitcast <2 x bfloat> %11498 to i32, !dbg !376 + %11553 = bitcast <2 x bfloat> %11500 to i32, !dbg !376 + %11554 = bitcast <2 x bfloat> %11502 to i32, !dbg !376 + %11555 = insertelement <4 x i32> poison, i32 %11551, i64 0, !dbg !376 + %11556 = insertelement <4 x i32> %11555, i32 %11552, i64 1, !dbg !376 + %11557 = insertelement <4 x i32> %11556, i32 %11553, i64 2, !dbg !376 + %11558 = insertelement <4 x i32> %11557, i32 %11554, i64 3, !dbg !376 + store <4 x i32> %11558, ptr addrspace(3) %11269, align 16, !dbg !376 + %11559 = bitcast <2 x bfloat> %11503 to i32, !dbg !376 + %11560 = bitcast <2 x bfloat> %11505 to i32, !dbg !376 + %11561 = bitcast <2 x bfloat> %11507 to i32, !dbg !376 + %11562 = bitcast <2 x bfloat> %11509 to i32, !dbg !376 + %11563 = insertelement <4 x i32> poison, i32 %11559, i64 0, !dbg !376 + %11564 = insertelement <4 x i32> %11563, i32 %11560, i64 1, !dbg !376 + %11565 = insertelement <4 x i32> %11564, i32 %11561, i64 2, !dbg !376 + %11566 = insertelement <4 x i32> %11565, i32 %11562, i64 3, !dbg !376 + store <4 x i32> %11566, ptr addrspace(3) %11279, align 16, !dbg !376 + %11567 = bitcast <2 x bfloat> %11504 to i32, !dbg !376 + %11568 = bitcast <2 x bfloat> %11506 to i32, !dbg !376 + %11569 = bitcast <2 x bfloat> %11508 to i32, !dbg !376 + %11570 = bitcast <2 x bfloat> %11510 to i32, !dbg !376 + %11571 = insertelement <4 x i32> poison, i32 %11567, i64 0, !dbg !376 + %11572 = insertelement <4 x i32> %11571, i32 %11568, i64 1, !dbg !376 + %11573 = insertelement <4 x i32> %11572, i32 %11569, i64 2, !dbg !376 + %11574 = insertelement <4 x i32> %11573, i32 %11570, i64 3, !dbg !376 + store <4 x i32> %11574, ptr addrspace(3) %11288, align 16, !dbg !376 + tail call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0), !dbg !376 + %11575 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11303) #3, !dbg !376 + %11576 = extractvalue { i32, i32, i32, i32 } %11575, 0, !dbg !376 + %11577 = extractvalue { i32, i32, i32, i32 } %11575, 1, !dbg !376 + %11578 = extractvalue { i32, i32, i32, i32 } %11575, 2, !dbg !376 + %11579 = extractvalue { i32, i32, i32, i32 } %11575, 3, !dbg !376 + %11580 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11310) #3, !dbg !376 + %11581 = extractvalue { i32, i32, i32, i32 } %11580, 0, !dbg !376 + %11582 = extractvalue { i32, i32, i32, i32 } %11580, 1, !dbg !376 + %11583 = extractvalue { i32, i32, i32, i32 } %11580, 2, !dbg !376 + %11584 = extractvalue { i32, i32, i32, i32 } %11580, 3, !dbg !376 + %11585 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11317) #3, !dbg !376 + %11586 = extractvalue { i32, i32, i32, i32 } %11585, 0, !dbg !376 + %11587 = extractvalue { i32, i32, i32, i32 } %11585, 1, !dbg !376 + %11588 = extractvalue { i32, i32, i32, i32 } %11585, 2, !dbg !376 + %11589 = extractvalue { i32, i32, i32, i32 } %11585, 3, !dbg !376 + %11590 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11324) #3, !dbg !376 + %11591 = extractvalue { i32, i32, i32, i32 } %11590, 0, !dbg !376 + %11592 = extractvalue { i32, i32, i32, i32 } %11590, 1, !dbg !376 + %11593 = extractvalue { i32, i32, i32, i32 } %11590, 2, !dbg !376 + %11594 = extractvalue { i32, i32, i32, i32 } %11590, 3, !dbg !376 + %11595 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11331) #3, !dbg !376 + %11596 = extractvalue { i32, i32, i32, i32 } %11595, 0, !dbg !376 + %11597 = extractvalue { i32, i32, i32, i32 } %11595, 1, !dbg !376 + %11598 = extractvalue { i32, i32, i32, i32 } %11595, 2, !dbg !376 + %11599 = extractvalue { i32, i32, i32, i32 } %11595, 3, !dbg !376 + %11600 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11338) #3, !dbg !376 + %11601 = extractvalue { i32, i32, i32, i32 } %11600, 0, !dbg !376 + %11602 = extractvalue { i32, i32, i32, i32 } %11600, 1, !dbg !376 + %11603 = extractvalue { i32, i32, i32, i32 } %11600, 2, !dbg !376 + %11604 = extractvalue { i32, i32, i32, i32 } %11600, 3, !dbg !376 + %11605 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11345) #3, !dbg !376 + %11606 = extractvalue { i32, i32, i32, i32 } %11605, 0, !dbg !376 + %11607 = extractvalue { i32, i32, i32, i32 } %11605, 1, !dbg !376 + %11608 = extractvalue { i32, i32, i32, i32 } %11605, 2, !dbg !376 + %11609 = extractvalue { i32, i32, i32, i32 } %11605, 3, !dbg !376 + %11610 = tail call { i32, i32, i32, i32 } asm sideeffect "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];", "=r,=r,=r,=r,r"(i32 %11352) #3, !dbg !376 + %11611 = extractvalue { i32, i32, i32, i32 } %11610, 0, !dbg !376 + %11612 = extractvalue { i32, i32, i32, i32 } %11610, 1, !dbg !376 + %11613 = extractvalue { i32, i32, i32, i32 } %11610, 2, !dbg !376 + %11614 = extractvalue { i32, i32, i32, i32 } %11610, 3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11576, i32 %11577, i32 %11578, i32 %11579, ptr addrspace(1) %11464, i1 %4777) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11581, i32 %11582, i32 %11583, i32 %11584, ptr addrspace(1) %11466, i1 %4778) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11586, i32 %11587, i32 %11588, i32 %11589, ptr addrspace(1) %11468, i1 %4779) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11591, i32 %11592, i32 %11593, i32 %11594, ptr addrspace(1) %11470, i1 %4780) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11596, i32 %11597, i32 %11598, i32 %11599, ptr addrspace(1) %11472, i1 %4781) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11601, i32 %11602, i32 %11603, i32 %11604, ptr addrspace(1) %11474, i1 %4782) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11606, i32 %11607, i32 %11608, i32 %11609, ptr addrspace(1) %11476, i1 %4783) #3, !dbg !376 + tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %11611, i32 %11612, i32 %11613, i32 %11614, ptr addrspace(1) %11478, i1 %4784) #3, !dbg !376 + br label %11615, !dbg !35 + +11615: ; preds = %._crit_edge1673, %11099 + ret void, !dbg !377 +} + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smax.i32(i32, i32) #1 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.smin.i32(i32, i32) #1 + +; Function Attrs: convergent nocallback nounwind +declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) #2 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.commit.group() #3 + +; Function Attrs: nounwind +declare void @llvm.nvvm.cp.async.wait.group(i32 immarg) #3 + +; Function Attrs: convergent nocallback nounwind memory(inaccessiblemem: readwrite) +declare i32 @llvm.nvvm.shfl.sync.idx.i32(i32, i32, i32, i32) #4 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.fence.sync.aligned() #5 + +; Function Attrs: convergent nounwind +declare void @llvm.nvvm.wgmma.commit_group.sync.aligned() #5 + +declare i32 @__nvvm_reflect(ptr) local_unnamed_addr #6 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.ftz.f(float) #7 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare float @llvm.nvvm.ex2.approx.f(float) #7 + +attributes #0 = { nounwind "nvvm.reqntid"="256" } +attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #2 = { convergent nocallback nounwind } +attributes #3 = { nounwind } +attributes #4 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) } +attributes #5 = { convergent nounwind } +attributes #6 = { "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #7 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) } + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2, !3} +!llvm.ident = !{!4} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly) +!1 = !DIFile(filename: "cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py", directory: "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = !{i32 4, !"nvvm-reflect-ftz", i32 1} +!4 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"} +!5 = distinct !DISubprogram(name: "triton_tem_fused_mul_1", linkageName: "triton_tem_fused_mul_1", scope: !1, file: !1, line: 18, type: !6, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!6 = !DISubroutineType(cc: DW_CC_normal, types: !7) +!7 = !{} +!8 = !DILocation(line: 94, column: 54, scope: !5) +!9 = !DILocation(line: 97, column: 74, scope: !5) +!10 = !DILocation(line: 97, column: 66, scope: !5) +!11 = !DILocation(line: 97, column: 100, scope: !5) +!12 = !DILocation(line: 97, column: 91, scope: !5) +!13 = !DILocation(line: 97, column: 82, scope: !5) +!14 = !DILocation(line: 97, column: 59, scope: !5) +!15 = !DILocation(line: 97, column: 111, scope: !5) +!16 = !DILocation(line: 111, column: 24, scope: !5) +!17 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !20) +!18 = distinct !DILexicalBlockFile(scope: !5, file: !19, discriminator: 0) +!19 = !DIFile(filename: "standard.py", directory: "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language") +!20 = !DILocation(line: 112, column: 36, scope: !5) +!21 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !20) +!22 = !DILocation(line: 115, column: 27, scope: !5) +!23 = !DILocation(line: 116, column: 28, scope: !5) +!24 = !DILocation(line: 124, column: 25, scope: !5) +!25 = !DILocation(line: 124, column: 59, scope: !5) +!26 = !DILocation(line: 100, column: 58, scope: !5) +!27 = !DILocation(line: 128, column: 50, scope: !5) +!28 = !DILocation(line: 128, column: 37, scope: !5) +!29 = !DILocation(line: 128, column: 61, scope: !5) +!30 = !DILocation(line: 131, column: 9, scope: !5) +!31 = !DILocation(line: 132, column: 9, scope: !5) +!32 = !DILocation(line: 133, column: 10, scope: !5) +!33 = !DILocation(line: 136, column: 26, scope: !5) +!34 = !DILocation(line: 139, column: 14, scope: !5) +!35 = !DILocation(line: 139, column: 7, scope: !5) +!36 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !37) +!37 = !DILocation(line: 113, column: 34, scope: !5) +!38 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !37) +!39 = !DILocation(line: 140, column: 24, scope: !5) +!40 = !DILocation(line: 144, column: 29, scope: !5) +!41 = !DILocation(line: 144, column: 54, scope: !5) +!42 = !DILocation(line: 144, column: 44, scope: !5) +!43 = !DILocation(line: 145, column: 35, scope: !5) +!44 = !DILocation(line: 158, column: 30, scope: !5) +!45 = !DILocation(line: 158, column: 52, scope: !5) +!46 = !DILocation(line: 158, column: 40, scope: !5) +!47 = !DILocation(line: 158, column: 63, scope: !5) +!48 = !DILocation(line: 159, column: 32, scope: !5) +!49 = !DILocation(line: 159, column: 55, scope: !5) +!50 = !DILocation(line: 159, column: 42, scope: !5) +!51 = !DILocation(line: 159, column: 66, scope: !5) +!52 = !DILocation(line: 161, column: 30, scope: !5) +!53 = !DILocation(line: 161, column: 35, scope: !5) +!54 = !DILocation(line: 161, column: 46, scope: !5) +!55 = !DILocation(line: 161, column: 56, scope: !5) +!56 = !DILocation(line: 163, column: 17, scope: !5) +!57 = !DILocation(line: 164, column: 19, scope: !5) +!58 = !DILocation(line: 167, column: 19, scope: !5) +!59 = !DILocation(line: 168, column: 21, scope: !5) +!60 = !DILocation(line: 169, column: 25, scope: !5) +!61 = !DILocation(line: 174, column: 36, scope: !5) +!62 = !DILocation(line: 175, column: 29, scope: !5) +!63 = !DILocation(line: 789, column: 38, scope: !64, inlinedAt: !65) +!64 = distinct !DILexicalBlockFile(scope: !5, file: !1, discriminator: 0) +!65 = !DILocation(line: 178, column: 107, scope: !5) +!66 = !DILocation(line: 789, column: 20, scope: !64, inlinedAt: !65) +!67 = !DILocation(line: 789, column: 56, scope: !64, inlinedAt: !65) +!68 = !DILocation(line: 789, column: 49, scope: !64, inlinedAt: !65) +!69 = !DILocation(line: 797, column: 52, scope: !64, inlinedAt: !65) +!70 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !65) +!71 = !DILocation(line: 789, column: 38, scope: !64, inlinedAt: !72) +!72 = !DILocation(line: 179, column: 111, scope: !5) +!73 = !DILocation(line: 789, column: 20, scope: !64, inlinedAt: !72) +!74 = !DILocation(line: 789, column: 49, scope: !64, inlinedAt: !72) +!75 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !72) +!76 = !DILocation(line: 188, column: 58, scope: !5) +!77 = !DILocation(line: 188, column: 34, scope: !5) +!78 = !DILocation(line: 188, column: 25, scope: !5) +!79 = !DILocation(line: 189, column: 33, scope: !5) +!80 = !DILocation(line: 189, column: 26, scope: !5) +!81 = !DILocation(line: 190, column: 30, scope: !5) +!82 = !DILocation(line: 190, column: 50, scope: !5) +!83 = !DILocation(line: 195, column: 30, scope: !5) +!84 = !DILocation(line: 196, column: 27, scope: !5) +!85 = !DILocation(line: 196, column: 41, scope: !5) +!86 = !DILocation(line: 197, column: 53, scope: !5) +!87 = !DILocation(line: 197, column: 39, scope: !5) +!88 = !DILocation(line: 199, column: 42, scope: !5) +!89 = !DILocation(line: 199, column: 29, scope: !5) +!90 = !DILocation(line: 390, column: 37, scope: !64, inlinedAt: !91) +!91 = !DILocation(line: 207, column: 12, scope: !5) +!92 = !DILocation(line: 390, column: 18, scope: !64, inlinedAt: !91) +!93 = !DILocation(line: 390, column: 49, scope: !64, inlinedAt: !91) +!94 = !DILocation(line: 391, column: 18, scope: !64, inlinedAt: !91) +!95 = !DILocation(line: 391, column: 49, scope: !64, inlinedAt: !91) +!96 = !DILocation(line: 395, column: 43, scope: !64, inlinedAt: !91) +!97 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !91) +!98 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !91) +!99 = !DILocation(line: 395, column: 101, scope: !64, inlinedAt: !91) +!100 = !DILocation(line: 395, column: 63, scope: !64, inlinedAt: !91) +!101 = !DILocation(line: 397, column: 28, scope: !64, inlinedAt: !91) +!102 = !DILocation(line: 795, column: 52, scope: !64, inlinedAt: !91) +!103 = !DILocation(line: 795, column: 23, scope: !64, inlinedAt: !91) +!104 = !DILocation(line: 414, column: 19, scope: !64, inlinedAt: !91) +!105 = !DILocation(line: 415, column: 19, scope: !64, inlinedAt: !91) +!106 = !DILocation(line: 417, column: 19, scope: !64, inlinedAt: !91) +!107 = !DILocation(line: 459, column: 19, scope: !64, inlinedAt: !91) +!108 = !DILocation(line: 762, column: 21, scope: !64, inlinedAt: !91) +!109 = !DILocation(line: 492, column: 91, scope: !64, inlinedAt: !91) +!110 = !DILocation(line: 481, column: 22, scope: !64, inlinedAt: !91) +!111 = !DILocation(line: 492, column: 79, scope: !64, inlinedAt: !91) +!112 = !DILocation(line: 492, column: 119, scope: !64, inlinedAt: !91) +!113 = !DILocation(line: 485, column: 23, scope: !64, inlinedAt: !91) +!114 = !DILocation(line: 484, column: 22, scope: !64, inlinedAt: !91) +!115 = !DILocation(line: 483, column: 23, scope: !64, inlinedAt: !91) +!116 = !DILocation(line: 487, column: 22, scope: !64, inlinedAt: !91) +!117 = !DILocation(line: 495, column: 25, scope: !64, inlinedAt: !91) +!118 = !DILocation(line: 513, column: 19, scope: !64, inlinedAt: !91) +!119 = !DILocation(line: 461, column: 14, scope: !64, inlinedAt: !91) +!120 = !DILocation(line: 486, column: 22, scope: !64, inlinedAt: !91) +!121 = !DILocation(line: 489, column: 23, scope: !64, inlinedAt: !91) +!122 = !DILocation(line: 494, column: 79, scope: !64, inlinedAt: !91) +!123 = !DILocation(line: 494, column: 91, scope: !64, inlinedAt: !91) +!124 = !DILocation(line: 494, column: 119, scope: !64, inlinedAt: !91) +!125 = !DILocation(line: 496, column: 24, scope: !64, inlinedAt: !91) +!126 = !DILocation(line: 497, column: 23, scope: !64, inlinedAt: !91) +!127 = !DILocation(line: 498, column: 23, scope: !64, inlinedAt: !91) +!128 = !DILocation(line: 503, column: 69, scope: !64, inlinedAt: !91) +!129 = !DILocation(line: 506, column: 27, scope: !64, inlinedAt: !91) +!130 = !DILocation(line: 507, column: 39, scope: !64, inlinedAt: !91) +!131 = !DILocation(line: 507, column: 21, scope: !64, inlinedAt: !91) +!132 = !DILocation(line: 512, column: 20, scope: !64, inlinedAt: !91) +!133 = !DILocation(line: 513, column: 14, scope: !64, inlinedAt: !91) +!134 = !DILocation(line: 533, column: 15, scope: !64, inlinedAt: !91) +!135 = !DILocation(line: 531, column: 43, scope: !64, inlinedAt: !91) +!136 = !DILocation(line: 535, column: 21, scope: !64, inlinedAt: !91) +!137 = !DILocation(line: 752, column: 33, scope: !64, inlinedAt: !91) +!138 = !DILocation(line: 753, column: 38, scope: !64, inlinedAt: !91) +!139 = !DILocation(line: 753, column: 24, scope: !64, inlinedAt: !91) +!140 = !DILocation(line: 754, column: 109, scope: !64, inlinedAt: !91) +!141 = !DILocation(line: 754, column: 113, scope: !64, inlinedAt: !91) +!142 = !DILocation(line: 754, column: 55, scope: !64, inlinedAt: !91) +!143 = !DILocation(line: 754, column: 25, scope: !64, inlinedAt: !91) +!144 = !DILocation(line: 755, column: 35, scope: !64, inlinedAt: !91) +!145 = !DILocation(line: 756, column: 34, scope: !64, inlinedAt: !91) +!146 = !DILocation(line: 756, column: 48, scope: !64, inlinedAt: !91) +!147 = !DILocation(line: 756, column: 63, scope: !64, inlinedAt: !91) +!148 = !DILocation(line: 757, column: 29, scope: !64, inlinedAt: !91) +!149 = !DILocation(line: 757, column: 61, scope: !64, inlinedAt: !91) +!150 = !DILocation(line: 757, column: 42, scope: !64, inlinedAt: !91) +!151 = !DILocation(line: 414, column: 28, scope: !64, inlinedAt: !91) +!152 = !DILocation(line: 214, column: 39, scope: !5) +!153 = !DILocation(line: 215, column: 31, scope: !5) +!154 = !DILocation(line: 215, column: 45, scope: !5) +!155 = !DILocation(line: 216, column: 62, scope: !5) +!156 = !DILocation(line: 216, column: 43, scope: !5) +!157 = !DILocation(line: 218, column: 33, scope: !5) +!158 = !DILocation(line: 390, column: 37, scope: !64, inlinedAt: !159) +!159 = !DILocation(line: 226, column: 16, scope: !5) +!160 = !DILocation(line: 390, column: 18, scope: !64, inlinedAt: !159) +!161 = !DILocation(line: 390, column: 49, scope: !64, inlinedAt: !159) +!162 = !DILocation(line: 391, column: 18, scope: !64, inlinedAt: !159) +!163 = !DILocation(line: 391, column: 49, scope: !64, inlinedAt: !159) +!164 = !DILocation(line: 395, column: 43, scope: !64, inlinedAt: !159) +!165 = !DILocation(line: 395, column: 63, scope: !64, inlinedAt: !159) +!166 = !DILocation(line: 397, column: 28, scope: !64, inlinedAt: !159) +!167 = !DILocation(line: 795, column: 52, scope: !64, inlinedAt: !159) +!168 = !DILocation(line: 795, column: 23, scope: !64, inlinedAt: !159) +!169 = !DILocation(line: 414, column: 19, scope: !64, inlinedAt: !159) +!170 = !DILocation(line: 415, column: 19, scope: !64, inlinedAt: !159) +!171 = !DILocation(line: 417, column: 19, scope: !64, inlinedAt: !159) +!172 = !DILocation(line: 459, column: 19, scope: !64, inlinedAt: !159) +!173 = !DILocation(line: 461, column: 14, scope: !64, inlinedAt: !159) +!174 = !DILocation(line: 506, column: 27, scope: !64, inlinedAt: !159) +!175 = !DILocation(line: 476, column: 79, scope: !64, inlinedAt: !159) +!176 = !DILocation(line: 507, column: 39, scope: !64, inlinedAt: !159) +!177 = !DILocation(line: 507, column: 21, scope: !64, inlinedAt: !159) +!178 = !DILocation(line: 512, column: 20, scope: !64, inlinedAt: !159) +!179 = !DILocation(line: 513, column: 19, scope: !64, inlinedAt: !159) +!180 = !DILocation(line: 513, column: 14, scope: !64, inlinedAt: !159) +!181 = !DILocation(line: 533, column: 15, scope: !64, inlinedAt: !159) +!182 = !DILocation(line: 520, column: 71, scope: !64, inlinedAt: !159) +!183 = !DILocation(line: 535, column: 21, scope: !64, inlinedAt: !159) +!184 = !DILocation(line: 752, column: 33, scope: !64, inlinedAt: !159) +!185 = !DILocation(line: 753, column: 38, scope: !64, inlinedAt: !159) +!186 = !DILocation(line: 753, column: 24, scope: !64, inlinedAt: !159) +!187 = !DILocation(line: 754, column: 109, scope: !64, inlinedAt: !159) +!188 = !DILocation(line: 754, column: 113, scope: !64, inlinedAt: !159) +!189 = !DILocation(line: 754, column: 55, scope: !64, inlinedAt: !159) +!190 = !DILocation(line: 754, column: 25, scope: !64, inlinedAt: !159) +!191 = !DILocation(line: 755, column: 35, scope: !64, inlinedAt: !159) +!192 = !DILocation(line: 756, column: 34, scope: !64, inlinedAt: !159) +!193 = !DILocation(line: 756, column: 48, scope: !64, inlinedAt: !159) +!194 = !DILocation(line: 756, column: 63, scope: !64, inlinedAt: !159) +!195 = !DILocation(line: 757, column: 29, scope: !64, inlinedAt: !159) +!196 = !DILocation(line: 757, column: 61, scope: !64, inlinedAt: !159) +!197 = !DILocation(line: 757, column: 42, scope: !64, inlinedAt: !159) +!198 = !DILocation(line: 414, column: 28, scope: !64, inlinedAt: !159) +!199 = !DILocation(line: 231, column: 24, scope: !5) +!200 = !DILocation(line: 231, column: 56, scope: !5) +!201 = !DILocation(line: 232, column: 14, scope: !5) +!202 = !DILocation(line: 236, column: 30, scope: !5) +!203 = !DILocation(line: 252, column: 25, scope: !5) +!204 = !DILocation(line: 253, column: 29, scope: !5) +!205 = !DILocation(line: 789, column: 38, scope: !64, inlinedAt: !206) +!206 = !DILocation(line: 256, column: 107, scope: !5) +!207 = !DILocation(line: 789, column: 20, scope: !64, inlinedAt: !206) +!208 = !DILocation(line: 789, column: 56, scope: !64, inlinedAt: !206) +!209 = !DILocation(line: 789, column: 49, scope: !64, inlinedAt: !206) +!210 = !DILocation(line: 797, column: 52, scope: !64, inlinedAt: !206) +!211 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !206) +!212 = !DILocation(line: 789, column: 20, scope: !64, inlinedAt: !213) +!213 = !DILocation(line: 257, column: 107, scope: !5) +!214 = !DILocation(line: 789, column: 49, scope: !64, inlinedAt: !213) +!215 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !213) +!216 = !DILocation(line: 263, column: 32, scope: !5) +!217 = !DILocation(line: 266, column: 56, scope: !5) +!218 = !DILocation(line: 267, column: 59, scope: !5) +!219 = !DILocation(line: 269, column: 34, scope: !5) +!220 = !DILocation(line: 286, column: 32, scope: !5) +!221 = !DILocation(line: 287, column: 30, scope: !5) +!222 = !DILocation(line: 287, column: 43, scope: !5) +!223 = !DILocation(line: 288, column: 55, scope: !5) +!224 = !DILocation(line: 288, column: 42, scope: !5) +!225 = !DILocation(line: 290, column: 45, scope: !5) +!226 = !DILocation(line: 290, column: 32, scope: !5) +!227 = !DILocation(line: 583, column: 37, scope: !64, inlinedAt: !228) +!228 = !DILocation(line: 298, column: 16, scope: !5) +!229 = !DILocation(line: 584, column: 38, scope: !64, inlinedAt: !228) +!230 = !DILocation(line: 590, column: 42, scope: !64, inlinedAt: !228) +!231 = !DILocation(line: 41, column: 22, scope: !18, inlinedAt: !228) +!232 = !DILocation(line: 41, column: 28, scope: !18, inlinedAt: !228) +!233 = !DILocation(line: 590, column: 98, scope: !64, inlinedAt: !228) +!234 = !DILocation(line: 590, column: 61, scope: !64, inlinedAt: !228) +!235 = !DILocation(line: 762, column: 21, scope: !64, inlinedAt: !228) +!236 = !DILocation(line: 692, column: 91, scope: !64, inlinedAt: !228) +!237 = !DILocation(line: 306, column: 41, scope: !5) +!238 = !DILocation(line: 307, column: 34, scope: !5) +!239 = !DILocation(line: 307, column: 47, scope: !5) +!240 = !DILocation(line: 308, column: 64, scope: !5) +!241 = !DILocation(line: 308, column: 46, scope: !5) +!242 = !DILocation(line: 310, column: 36, scope: !5) +!243 = !DILocation(line: 583, column: 37, scope: !64, inlinedAt: !244) +!244 = !DILocation(line: 318, column: 20, scope: !5) +!245 = !DILocation(line: 584, column: 38, scope: !64, inlinedAt: !244) +!246 = !DILocation(line: 590, column: 42, scope: !64, inlinedAt: !244) +!247 = !DILocation(line: 590, column: 61, scope: !64, inlinedAt: !244) +!248 = !DILocation(line: 658, column: 20, scope: !64, inlinedAt: !228) +!249 = !DILocation(line: 262, column: 30, scope: !5) +!250 = !DILocation(line: 263, column: 51, scope: !5) +!251 = !DILocation(line: 266, column: 44, scope: !5) +!252 = !DILocation(line: 266, column: 67, scope: !5) +!253 = !DILocation(line: 267, column: 36, scope: !5) +!254 = !DILocation(line: 267, column: 46, scope: !5) +!255 = !DILocation(line: 267, column: 70, scope: !5) +!256 = !DILocation(line: 269, column: 50, scope: !5) +!257 = !DILocation(line: 269, column: 60, scope: !5) +!258 = !DILocation(line: 271, column: 21, scope: !5) +!259 = !DILocation(line: 272, column: 23, scope: !5) +!260 = !DILocation(line: 275, column: 25, scope: !5) +!261 = !DILocation(line: 276, column: 29, scope: !5) +!262 = !DILocation(line: 583, column: 18, scope: !64, inlinedAt: !228) +!263 = !DILocation(line: 583, column: 49, scope: !64, inlinedAt: !228) +!264 = !DILocation(line: 584, column: 19, scope: !64, inlinedAt: !228) +!265 = !DILocation(line: 584, column: 51, scope: !64, inlinedAt: !228) +!266 = !DILocation(line: 795, column: 23, scope: !64, inlinedAt: !228) +!267 = !DILocation(line: 656, column: 28, scope: !64, inlinedAt: !228) +!268 = !DILocation(line: 656, column: 22, scope: !64, inlinedAt: !228) +!269 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !228) +!270 = !DILocation(line: 712, column: 29, scope: !64, inlinedAt: !228) +!271 = !DILocation(line: 712, column: 21, scope: !64, inlinedAt: !228) +!272 = !DILocation(line: 608, column: 19, scope: !64, inlinedAt: !228) +!273 = !DILocation(line: 609, column: 19, scope: !64, inlinedAt: !228) +!274 = !DILocation(line: 592, column: 28, scope: !64, inlinedAt: !228) +!275 = !DILocation(line: 795, column: 52, scope: !64, inlinedAt: !228) +!276 = !DILocation(line: 657, column: 26, scope: !64, inlinedAt: !228) +!277 = !DILocation(line: 657, column: 46, scope: !64, inlinedAt: !228) +!278 = !DILocation(line: 660, column: 15, scope: !64, inlinedAt: !228) +!279 = !DILocation(line: 679, column: 24, scope: !64, inlinedAt: !228) +!280 = !DILocation(line: 683, column: 25, scope: !64, inlinedAt: !228) +!281 = !DILocation(line: 681, column: 25, scope: !64, inlinedAt: !228) +!282 = !DILocation(line: 682, column: 24, scope: !64, inlinedAt: !228) +!283 = !DILocation(line: 690, column: 79, scope: !64, inlinedAt: !228) +!284 = !DILocation(line: 690, column: 91, scope: !64, inlinedAt: !228) +!285 = !DILocation(line: 690, column: 119, scope: !64, inlinedAt: !228) +!286 = !DILocation(line: 693, column: 25, scope: !64, inlinedAt: !228) +!287 = !DILocation(line: 694, column: 24, scope: !64, inlinedAt: !228) +!288 = !DILocation(line: 696, column: 24, scope: !64, inlinedAt: !228) +!289 = !DILocation(line: 700, column: 69, scope: !64, inlinedAt: !228) +!290 = !DILocation(line: 703, column: 27, scope: !64, inlinedAt: !228) +!291 = !DILocation(line: 704, column: 40, scope: !64, inlinedAt: !228) +!292 = !DILocation(line: 704, column: 22, scope: !64, inlinedAt: !228) +!293 = !DILocation(line: 708, column: 24, scope: !64, inlinedAt: !228) +!294 = !DILocation(line: 708, column: 43, scope: !64, inlinedAt: !228) +!295 = !DILocation(line: 714, column: 20, scope: !64, inlinedAt: !228) +!296 = !DILocation(line: 715, column: 22, scope: !64, inlinedAt: !228) +!297 = !DILocation(line: 715, column: 16, scope: !64, inlinedAt: !228) +!298 = !DILocation(line: 739, column: 24, scope: !64, inlinedAt: !228) +!299 = !DILocation(line: 737, column: 45, scope: !64, inlinedAt: !228) +!300 = !DILocation(line: 739, column: 43, scope: !64, inlinedAt: !228) +!301 = !DILocation(line: 610, column: 19, scope: !64, inlinedAt: !228) +!302 = !DILocation(line: 752, column: 33, scope: !64, inlinedAt: !228) +!303 = !DILocation(line: 753, column: 38, scope: !64, inlinedAt: !228) +!304 = !DILocation(line: 753, column: 24, scope: !64, inlinedAt: !228) +!305 = !DILocation(line: 754, column: 109, scope: !64, inlinedAt: !228) +!306 = !DILocation(line: 754, column: 113, scope: !64, inlinedAt: !228) +!307 = !DILocation(line: 754, column: 55, scope: !64, inlinedAt: !228) +!308 = !DILocation(line: 754, column: 25, scope: !64, inlinedAt: !228) +!309 = !DILocation(line: 755, column: 35, scope: !64, inlinedAt: !228) +!310 = !DILocation(line: 756, column: 34, scope: !64, inlinedAt: !228) +!311 = !DILocation(line: 756, column: 48, scope: !64, inlinedAt: !228) +!312 = !DILocation(line: 756, column: 63, scope: !64, inlinedAt: !228) +!313 = !DILocation(line: 757, column: 29, scope: !64, inlinedAt: !228) +!314 = !DILocation(line: 757, column: 61, scope: !64, inlinedAt: !228) +!315 = !DILocation(line: 757, column: 42, scope: !64, inlinedAt: !228) +!316 = !DILocation(line: 608, column: 28, scope: !64, inlinedAt: !228) +!317 = !DILocation(line: 609, column: 28, scope: !64, inlinedAt: !228) +!318 = !DILocation(line: 656, column: 52, scope: !64, inlinedAt: !228) +!319 = !DILocation(line: 797, column: 52, scope: !64, inlinedAt: !228) +!320 = !DILocation(line: 583, column: 18, scope: !64, inlinedAt: !244) +!321 = !DILocation(line: 583, column: 49, scope: !64, inlinedAt: !244) +!322 = !DILocation(line: 584, column: 19, scope: !64, inlinedAt: !244) +!323 = !DILocation(line: 584, column: 51, scope: !64, inlinedAt: !244) +!324 = !DILocation(line: 795, column: 23, scope: !64, inlinedAt: !244) +!325 = !DILocation(line: 656, column: 28, scope: !64, inlinedAt: !244) +!326 = !DILocation(line: 656, column: 22, scope: !64, inlinedAt: !244) +!327 = !DILocation(line: 797, column: 23, scope: !64, inlinedAt: !244) +!328 = !DILocation(line: 712, column: 29, scope: !64, inlinedAt: !244) +!329 = !DILocation(line: 712, column: 21, scope: !64, inlinedAt: !244) +!330 = !DILocation(line: 608, column: 19, scope: !64, inlinedAt: !244) +!331 = !DILocation(line: 609, column: 19, scope: !64, inlinedAt: !244) +!332 = !DILocation(line: 592, column: 28, scope: !64, inlinedAt: !244) +!333 = !DILocation(line: 795, column: 52, scope: !64, inlinedAt: !244) +!334 = !DILocation(line: 657, column: 26, scope: !64, inlinedAt: !244) +!335 = !DILocation(line: 657, column: 46, scope: !64, inlinedAt: !244) +!336 = !DILocation(line: 658, column: 20, scope: !64, inlinedAt: !244) +!337 = !DILocation(line: 660, column: 15, scope: !64, inlinedAt: !244) +!338 = !DILocation(line: 703, column: 27, scope: !64, inlinedAt: !244) +!339 = !DILocation(line: 674, column: 78, scope: !64, inlinedAt: !244) +!340 = !DILocation(line: 704, column: 40, scope: !64, inlinedAt: !244) +!341 = !DILocation(line: 704, column: 22, scope: !64, inlinedAt: !244) +!342 = !DILocation(line: 708, column: 24, scope: !64, inlinedAt: !244) +!343 = !DILocation(line: 708, column: 43, scope: !64, inlinedAt: !244) +!344 = !DILocation(line: 714, column: 20, scope: !64, inlinedAt: !244) +!345 = !DILocation(line: 715, column: 22, scope: !64, inlinedAt: !244) +!346 = !DILocation(line: 715, column: 16, scope: !64, inlinedAt: !244) +!347 = !DILocation(line: 739, column: 24, scope: !64, inlinedAt: !244) +!348 = !DILocation(line: 723, column: 70, scope: !64, inlinedAt: !244) +!349 = !DILocation(line: 739, column: 43, scope: !64, inlinedAt: !244) +!350 = !DILocation(line: 752, column: 33, scope: !64, inlinedAt: !244) +!351 = !DILocation(line: 753, column: 38, scope: !64, inlinedAt: !244) +!352 = !DILocation(line: 753, column: 24, scope: !64, inlinedAt: !244) +!353 = !DILocation(line: 754, column: 109, scope: !64, inlinedAt: !244) +!354 = !DILocation(line: 754, column: 113, scope: !64, inlinedAt: !244) +!355 = !DILocation(line: 754, column: 55, scope: !64, inlinedAt: !244) +!356 = !DILocation(line: 754, column: 25, scope: !64, inlinedAt: !244) +!357 = !DILocation(line: 755, column: 35, scope: !64, inlinedAt: !244) +!358 = !DILocation(line: 756, column: 34, scope: !64, inlinedAt: !244) +!359 = !DILocation(line: 756, column: 48, scope: !64, inlinedAt: !244) +!360 = !DILocation(line: 756, column: 63, scope: !64, inlinedAt: !244) +!361 = !DILocation(line: 757, column: 29, scope: !64, inlinedAt: !244) +!362 = !DILocation(line: 757, column: 61, scope: !64, inlinedAt: !244) +!363 = !DILocation(line: 757, column: 42, scope: !64, inlinedAt: !244) +!364 = !DILocation(line: 608, column: 28, scope: !64, inlinedAt: !244) +!365 = !DILocation(line: 609, column: 28, scope: !64, inlinedAt: !244) +!366 = !DILocation(line: 610, column: 19, scope: !64, inlinedAt: !244) +!367 = !DILocation(line: 656, column: 52, scope: !64, inlinedAt: !244) +!368 = !DILocation(line: 797, column: 52, scope: !64, inlinedAt: !244) +!369 = !DILocation(line: 323, column: 23, scope: !5) +!370 = !DILocation(line: 323, column: 55, scope: !5) +!371 = !DILocation(line: 332, column: 30, scope: !5) +!372 = !DILocation(line: 334, column: 14, scope: !5) +!373 = !DILocation(line: 345, column: 55, scope: !5) +!374 = !DILocation(line: 345, column: 69, scope: !5) +!375 = !DILocation(line: 345, column: 29, scope: !5) +!376 = !DILocation(line: 345, column: 99, scope: !5) +!377 = !DILocation(line: 139, column: 4, scope: !5) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ptx b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ptx new file mode 100644 index 0000000000000000000000000000000000000000..ac66abccef00965c7d388fc1f203562a5e6fbca6 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ptx @@ -0,0 +1,9372 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_tem_fused_mul_1 // -- Begin function triton_tem_fused_mul_1 +.extern .shared .align 16 .b8 global_smem[]; +.global .align 1 .b8 _$_str[11] = {95, 95, 67, 85, 68, 65, 95, 70, 84, 90}; + // @triton_tem_fused_mul_1 +.visible .entry triton_tem_fused_mul_1( + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_0, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_1, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_2, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_3, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_4, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_5, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_6, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_7, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_8, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_9, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_10, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_11, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_12, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_13, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_14, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_15, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_16, + .param .u32 triton_tem_fused_mul_1_param_17, + .param .u32 triton_tem_fused_mul_1_param_18, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_19, + .param .u64 .ptr .global .align 1 triton_tem_fused_mul_1_param_20 +) +.reqntid 256 +{ + .reg .pred %p<1269>; + .reg .b16 %rs<371>; + .reg .b32 %r<15236>; + .reg .b64 %rd<1137>; + .loc 1 18 0 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:18:0 +$L__func_begin0: + .loc 1 18 0 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:18:0 + +// %bb.0: + ld.param.b32 %r2339, [triton_tem_fused_mul_1_param_18]; + ld.param.b32 %r2338, [triton_tem_fused_mul_1_param_17]; + ld.param.b64 %rd181, [triton_tem_fused_mul_1_param_5]; + ld.param.b64 %rd180, [triton_tem_fused_mul_1_param_4]; + ld.param.b64 %rd179, [triton_tem_fused_mul_1_param_3]; + ld.param.b64 %rd178, [triton_tem_fused_mul_1_param_0]; +$L__tmp0: + .loc 1 94 54 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:94:54 + shl.b32 %r1, %r2338, 12; + ld.param.b64 %rd192, [triton_tem_fused_mul_1_param_1]; + .loc 1 97 74 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:74 + setp.lt.s32 %p9, %r2338, 2; + ld.param.b64 %rd193, [triton_tem_fused_mul_1_param_2]; + .loc 1 97 66 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:66 + selp.b32 %r2340, 1, 0, %p9; + .loc 1 97 100 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:100 + setp.gt.s32 %p10, %r2338, 1; + .loc 1 97 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:91 + selp.b32 %r2341, %r2338, 0, %p10; + .loc 1 97 82 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:82 + add.s32 %r2342, %r2341, %r2340; + .loc 1 97 59 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:59 + shl.b32 %r2, %r2342, 12; + .loc 1 97 111 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:97:111 + shl.b32 %r3, %r2342, 7; + .loc 1 111 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:111:24 + mov.u32 %r4, %ctaid.x; +$L__tmp1: + .loc 2 41 22 // standard.py:41:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:112:36 ] + add.s32 %r2343, %r2339, 127; + .loc 2 41 28 // standard.py:41:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:112:36 ] + shr.s32 %r2344, %r2343, 31; + shr.u32 %r2345, %r2344, 25; + add.s32 %r2346, %r2343, %r2345; + shr.s32 %r5, %r2346, 7; +$L__tmp2: + .loc 1 115 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:115:27 + mov.u32 %r6, %ctaid.y; + .loc 1 116 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:116:28 + mov.u32 %r7, %ctaid.z; + .loc 1 124 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:124:25 + shl.b32 %r8, %r7, 7; + .loc 1 131 9 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:131:9 + mul.wide.u32 %rd195, %r8, 2; + add.s64 %rd1, %rd192, %rd195; + .loc 1 132 9 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:132:9 + add.s64 %rd2, %rd193, %rd195; + .loc 1 136 26 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:136:26 + mov.u32 %r9, %tid.x; + shr.u32 %r10, %r9, 5; + and.b32 %r11, %r9, 240; + bfe.u32 %r12, %r9, 4, 4; + or.b32 %r13, %r12, 16; + or.b32 %r14, %r12, 32; + or.b32 %r15, %r12, 48; + or.b32 %r16, %r12, 64; + or.b32 %r17, %r12, 80; + or.b32 %r18, %r12, 96; + or.b32 %r19, %r12, 112; + shr.u32 %r2350, %r9, 1; + and.b32 %r2351, %r2350, 112; + bfe.u32 %r2352, %r9, 2, 3; + or.b32 %r20, %r2351, %r2352; + or.b32 %r21, %r20, 8; + .loc 1 139 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:139:14 + setp.lt.s32 %p11, %r4, %r5; + .loc 1 139 7 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:139:7 + @%p11 bra $L__BB0_8; + bra.uni $L__BB0_1; +$L__BB0_8: + .loc 1 0 7 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:7 + ld.param.b64 %rd191, [triton_tem_fused_mul_1_param_16]; + ld.param.b64 %rd190, [triton_tem_fused_mul_1_param_15]; + ld.param.b64 %rd189, [triton_tem_fused_mul_1_param_14]; + ld.param.b64 %rd186, [triton_tem_fused_mul_1_param_11]; + ld.param.b64 %rd185, [triton_tem_fused_mul_1_param_10]; + ld.param.b64 %rd194, [triton_tem_fused_mul_1_param_7]; + mul.lo.s32 %r2347, %r6, %r2339; + shl.b32 %r2348, %r2347, 10; + add.s32 %r2349, %r2348, %r8; + mad.wide.s32 %rd3, %r2349, 2, %rd194; + .loc 1 252 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:252:25 + shl.b32 %r7370, %r4, 7; + .loc 1 253 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:253:29 + or.b32 %r679, %r12, %r7370; + or.b32 %r680, %r13, %r7370; + or.b32 %r681, %r14, %r7370; + or.b32 %r682, %r15, %r7370; + or.b32 %r683, %r16, %r7370; + or.b32 %r684, %r17, %r7370; + or.b32 %r685, %r18, %r7370; + or.b32 %r686, %r19, %r7370; +$L__tmp3: + .loc 1 789 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + shl.b32 %r7371, %r679, 10; + shl.b32 %r7372, %r680, 10; + shl.b32 %r7373, %r681, 10; + shl.b32 %r7374, %r682, 10; + shl.b32 %r7375, %r683, 10; + shl.b32 %r7376, %r684, 10; + shl.b32 %r7377, %r685, 10; + shl.b32 %r7378, %r686, 10; + .loc 1 789 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + cvt.s64.s32 %rd67, %r7371; + mul.wide.s32 %rd527, %r7371, 2; + add.s64 %rd528, %rd1, %rd527; + cvt.s64.s32 %rd68, %r7372; + mul.wide.s32 %rd529, %r7372, 2; + add.s64 %rd530, %rd1, %rd529; + cvt.s64.s32 %rd69, %r7373; + mul.wide.s32 %rd531, %r7373, 2; + add.s64 %rd532, %rd1, %rd531; + cvt.s64.s32 %rd70, %r7374; + mul.wide.s32 %rd533, %r7374, 2; + add.s64 %rd534, %rd1, %rd533; + cvt.s64.s32 %rd71, %r7375; + mul.wide.s32 %rd535, %r7375, 2; + add.s64 %rd536, %rd1, %rd535; + cvt.s64.s32 %rd72, %r7376; + mul.wide.s32 %rd537, %r7376, 2; + add.s64 %rd538, %rd1, %rd537; + cvt.s64.s32 %rd73, %r7377; + mul.wide.s32 %rd539, %r7377, 2; + add.s64 %rd540, %rd1, %rd539; + cvt.s64.s32 %rd74, %r7378; + mul.wide.s32 %rd541, %r7378, 2; + add.s64 %rd542, %rd1, %rd541; + .loc 1 789 56 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:56 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + shl.b32 %r7379, %r9, 3; + and.b32 %r7380, %r7379, 120; + .loc 1 789 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + cvt.u64.u32 %rd75, %r7380; + mul.wide.u32 %rd543, %r7380, 2; + add.s64 %rd506, %rd528, %rd543; + add.s64 %rd507, %rd530, %rd543; + add.s64 %rd508, %rd532, %rd543; + add.s64 %rd509, %rd534, %rd543; + add.s64 %rd510, %rd536, %rd543; + add.s64 %rd511, %rd538, %rd543; + add.s64 %rd512, %rd540, %rd543; + add.s64 %rd513, %rd542, %rd543; + .loc 1 797 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + setp.lt.s32 %p1253, %r679, %r2339; + setp.lt.s32 %p1254, %r680, %r2339; + setp.lt.s32 %p1255, %r681, %r2339; + setp.lt.s32 %p1256, %r682, %r2339; + setp.lt.s32 %p1257, %r683, %r2339; + setp.lt.s32 %p1258, %r684, %r2339; + setp.lt.s32 %p1259, %r685, %r2339; + setp.lt.s32 %p1260, %r686, %r2339; + mov.b32 %r7241, 0; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:256:107 ] + // begin inline asm + mov.u32 %r7237, %r7241; + mov.u32 %r7238, %r7241; + mov.u32 %r7239, %r7241; + mov.u32 %r7240, %r7241; + @%p1253 ld.global.v4.b32 { %r7237, %r7238, %r7239, %r7240 }, [ %rd506 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7245, %r7241; + mov.u32 %r7246, %r7241; + mov.u32 %r7247, %r7241; + mov.u32 %r7248, %r7241; + @%p1254 ld.global.v4.b32 { %r7245, %r7246, %r7247, %r7248 }, [ %rd507 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7253, %r7241; + mov.u32 %r7254, %r7241; + mov.u32 %r7255, %r7241; + mov.u32 %r7256, %r7241; + @%p1255 ld.global.v4.b32 { %r7253, %r7254, %r7255, %r7256 }, [ %rd508 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7261, %r7241; + mov.u32 %r7262, %r7241; + mov.u32 %r7263, %r7241; + mov.u32 %r7264, %r7241; + @%p1256 ld.global.v4.b32 { %r7261, %r7262, %r7263, %r7264 }, [ %rd509 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7269, %r7241; + mov.u32 %r7270, %r7241; + mov.u32 %r7271, %r7241; + mov.u32 %r7272, %r7241; + @%p1257 ld.global.v4.b32 { %r7269, %r7270, %r7271, %r7272 }, [ %rd510 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7277, %r7241; + mov.u32 %r7278, %r7241; + mov.u32 %r7279, %r7241; + mov.u32 %r7280, %r7241; + @%p1258 ld.global.v4.b32 { %r7277, %r7278, %r7279, %r7280 }, [ %rd511 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7285, %r7241; + mov.u32 %r7286, %r7241; + mov.u32 %r7287, %r7241; + mov.u32 %r7288, %r7241; + @%p1259 ld.global.v4.b32 { %r7285, %r7286, %r7287, %r7288 }, [ %rd512 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7293, %r7241; + mov.u32 %r7294, %r7241; + mov.u32 %r7295, %r7241; + mov.u32 %r7296, %r7241; + @%p1260 ld.global.v4.b32 { %r7293, %r7294, %r7295, %r7296 }, [ %rd513 + 0 ]; + // end inline asm + shl.b32 %r7381, %r9, 4; + and.b32 %r7382, %r7381, 112; + shl.b32 %r7383, %r11, 3; + and.b32 %r7384, %r9, 112; + and.b32 %r7385, %r9, 8; + shl.b32 %r7386, %r7385, 11; + or.b32 %r7387, %r7382, %r7383; + xor.b32 %r7388, %r7387, %r7384; + or.b32 %r7389, %r7388, %r7386; + mov.b32 %r7390, global_smem; + add.s32 %r7391, %r7390, %r7389; + st.shared.v4.b32 [%r7391+99328], {%r7237, %r7238, %r7239, %r7240}; + st.shared.v4.b32 [%r7391+101376], {%r7245, %r7246, %r7247, %r7248}; + st.shared.v4.b32 [%r7391+103424], {%r7253, %r7254, %r7255, %r7256}; + st.shared.v4.b32 [%r7391+105472], {%r7261, %r7262, %r7263, %r7264}; + st.shared.v4.b32 [%r7391+107520], {%r7269, %r7270, %r7271, %r7272}; + st.shared.v4.b32 [%r7391+109568], {%r7277, %r7278, %r7279, %r7280}; + st.shared.v4.b32 [%r7391+111616], {%r7285, %r7286, %r7287, %r7288}; + st.shared.v4.b32 [%r7391+113664], {%r7293, %r7294, %r7295, %r7296}; +$L__tmp4: + .loc 1 789 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:257:107 ] + add.s64 %rd544, %rd2, %rd527; + add.s64 %rd545, %rd2, %rd529; + add.s64 %rd546, %rd2, %rd531; + add.s64 %rd547, %rd2, %rd533; + add.s64 %rd548, %rd2, %rd535; + add.s64 %rd549, %rd2, %rd537; + add.s64 %rd550, %rd2, %rd539; + add.s64 %rd551, %rd2, %rd541; + .loc 1 789 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:257:107 ] + add.s64 %rd514, %rd544, %rd543; + add.s64 %rd515, %rd545, %rd543; + add.s64 %rd516, %rd546, %rd543; + add.s64 %rd517, %rd547, %rd543; + add.s64 %rd518, %rd548, %rd543; + add.s64 %rd519, %rd549, %rd543; + add.s64 %rd520, %rd550, %rd543; + add.s64 %rd521, %rd551, %rd543; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:257:107 ] + // begin inline asm + mov.u32 %r7301, %r7241; + mov.u32 %r7302, %r7241; + mov.u32 %r7303, %r7241; + mov.u32 %r7304, %r7241; + @%p1253 ld.global.v4.b32 { %r7301, %r7302, %r7303, %r7304 }, [ %rd514 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7309, %r7241; + mov.u32 %r7310, %r7241; + mov.u32 %r7311, %r7241; + mov.u32 %r7312, %r7241; + @%p1254 ld.global.v4.b32 { %r7309, %r7310, %r7311, %r7312 }, [ %rd515 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7317, %r7241; + mov.u32 %r7318, %r7241; + mov.u32 %r7319, %r7241; + mov.u32 %r7320, %r7241; + @%p1255 ld.global.v4.b32 { %r7317, %r7318, %r7319, %r7320 }, [ %rd516 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7325, %r7241; + mov.u32 %r7326, %r7241; + mov.u32 %r7327, %r7241; + mov.u32 %r7328, %r7241; + @%p1256 ld.global.v4.b32 { %r7325, %r7326, %r7327, %r7328 }, [ %rd517 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7333, %r7241; + mov.u32 %r7334, %r7241; + mov.u32 %r7335, %r7241; + mov.u32 %r7336, %r7241; + @%p1257 ld.global.v4.b32 { %r7333, %r7334, %r7335, %r7336 }, [ %rd518 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7341, %r7241; + mov.u32 %r7342, %r7241; + mov.u32 %r7343, %r7241; + mov.u32 %r7344, %r7241; + @%p1258 ld.global.v4.b32 { %r7341, %r7342, %r7343, %r7344 }, [ %rd519 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7349, %r7241; + mov.u32 %r7350, %r7241; + mov.u32 %r7351, %r7241; + mov.u32 %r7352, %r7241; + @%p1259 ld.global.v4.b32 { %r7349, %r7350, %r7351, %r7352 }, [ %rd520 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r7357, %r7241; + mov.u32 %r7358, %r7241; + mov.u32 %r7359, %r7241; + mov.u32 %r7360, %r7241; + @%p1260 ld.global.v4.b32 { %r7357, %r7358, %r7359, %r7360 }, [ %rd521 + 0 ]; + // end inline asm + st.shared.v4.b32 [%r7391+132096], {%r7301, %r7302, %r7303, %r7304}; + st.shared.v4.b32 [%r7391+134144], {%r7309, %r7310, %r7311, %r7312}; + st.shared.v4.b32 [%r7391+136192], {%r7317, %r7318, %r7319, %r7320}; + st.shared.v4.b32 [%r7391+138240], {%r7325, %r7326, %r7327, %r7328}; + st.shared.v4.b32 [%r7391+140288], {%r7333, %r7334, %r7335, %r7336}; + st.shared.v4.b32 [%r7391+142336], {%r7341, %r7342, %r7343, %r7344}; + st.shared.v4.b32 [%r7391+144384], {%r7349, %r7350, %r7351, %r7352}; + st.shared.v4.b32 [%r7391+146432], {%r7357, %r7358, %r7359, %r7360}; +$L__tmp5: + .loc 1 266 56 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:266:56 + mul.lo.s32 %r687, %r1, %r6; + .loc 1 267 59 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:267:59 + mul.lo.s32 %r688, %r2, %r6; + .loc 1 269 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:269:34 + shl.b32 %r689, %r6, 5; + .loc 1 286 32 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:286:32 + mul.wide.u32 %rd552, %r4, 4; + add.s64 %rd522, %rd186, %rd552; + mov.pred %p510, -1; + .loc 1 287 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:287:30 + // begin inline asm + mov.u32 %r7365, 0x0; + @%p510 ld.global.b32 { %r7365 }, [ %rd522 + 0 ]; + // end inline asm + .loc 1 287 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:287:43 + shl.b32 %r7392, %r7365, 7; + .loc 1 288 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:288:55 + add.s64 %rd523, %rd185, %rd552; + .loc 1 288 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:288:42 + // begin inline asm + mov.u32 %r7366, 0x0; + @%p510 ld.global.b32 { %r7366 }, [ %rd523 + 0 ]; + // end inline asm + .loc 1 290 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:290:45 + and.b32 %r691, %r9, 3; + shl.b32 %r692, %r691, 1; + or.b32 %r7393, %r692, 1; + or.b32 %r7394, %r692, 8; + or.b32 %r7395, %r692, 9; + .loc 1 290 32 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:290:32 + or.b32 %r7396, %r7392, %r12; + or.b32 %r7397, %r7392, %r13; + or.b32 %r7398, %r7392, %r14; + or.b32 %r7399, %r7392, %r15; +$L__tmp6: + .loc 1 583 37 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:37 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r7400, %r7396, 12; + shl.b32 %r7401, %r7397, 12; + shl.b32 %r7402, %r7398, 12; + shl.b32 %r7403, %r7399, 12; + .loc 1 584 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r7404, %r7396, 7; + shl.b32 %r7405, %r7397, 7; + shl.b32 %r7406, %r7398, 7; + shl.b32 %r7407, %r7399, 7; + .loc 1 590 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:590:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r693, %r7366, 1; + .loc 2 41 22 // standard.py:41:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r7408, %r2338, 63; + .loc 2 41 28 // standard.py:41:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.s32 %r7409, %r7408, 31; + shr.u32 %r7410, %r7409, 26; + add.s32 %r7411, %r7408, %r7410; + shr.s32 %r7412, %r7411, 6; + .loc 1 590 98 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:590:98 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + max.s32 %r7413, %r7412, 1; + .loc 1 590 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:590:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + min.s32 %r7414, %r693, %r7413; +$L__tmp7: + .loc 1 253 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:253:29 + or.b32 %r7415, %r20, %r7370; + or.b32 %r7416, %r21, %r7370; + .loc 1 290 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:290:45 + or.b32 %r7417, %r692, 16; + or.b32 %r7418, %r692, 17; + or.b32 %r7419, %r692, 24; + or.b32 %r7420, %r692, 25; + or.b32 %r7421, %r692, 32; + or.b32 %r7422, %r692, 33; + or.b32 %r7423, %r692, 40; + or.b32 %r7424, %r692, 41; + or.b32 %r7425, %r692, 48; + or.b32 %r7426, %r692, 49; + or.b32 %r7427, %r692, 56; + or.b32 %r7428, %r692, 57; + .loc 1 290 32 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:290:32 + or.b32 %r785, %r7392, %r7428; + or.b32 %r784, %r7392, %r7427; + or.b32 %r783, %r7392, %r7426; + or.b32 %r782, %r7392, %r7425; + or.b32 %r781, %r7392, %r7424; + or.b32 %r780, %r7392, %r7423; + or.b32 %r779, %r7392, %r7422; + or.b32 %r778, %r7392, %r7421; + or.b32 %r777, %r7392, %r7420; + or.b32 %r776, %r7392, %r7419; + or.b32 %r775, %r7392, %r7418; + or.b32 %r774, %r7392, %r7417; + or.b32 %r773, %r7392, %r7395; + or.b32 %r772, %r7392, %r7394; + or.b32 %r771, %r7392, %r7393; + or.b32 %r770, %r7392, %r692; +$L__tmp8: + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + rem.s32 %r1003, %r7416, %r2339; + rem.s32 %r1002, %r7415, %r2339; + .loc 1 692 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:692:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.u32 %r741, %r1002, 4; + shr.u32 %r739, %r1003, 4; +$L__tmp9: + .loc 1 306 41 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:306:41 + add.s64 %rd524, %rd190, %rd552; + .loc 1 307 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:307:34 + // begin inline asm + mov.u32 %r7367, 0x0; + @%p510 ld.global.b32 { %r7367 }, [ %rd524 + 0 ]; + // end inline asm + .loc 1 307 47 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:307:47 + shl.b32 %r742, %r7367, 7; + .loc 1 308 64 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:308:64 + add.s64 %rd525, %rd189, %rd552; + .loc 1 308 46 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:308:46 + // begin inline asm + mov.u32 %r7368, 0x0; + @%p510 ld.global.b32 { %r7368 }, [ %rd525 + 0 ]; + // end inline asm + .loc 1 310 36 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:310:36 + or.b32 %r744, %r742, %r692; + or.b32 %r745, %r742, %r7393; + or.b32 %r746, %r742, %r7394; + or.b32 %r747, %r742, %r7395; + or.b32 %r748, %r742, %r7417; + or.b32 %r749, %r742, %r7418; + or.b32 %r750, %r742, %r7419; + or.b32 %r751, %r742, %r7420; + or.b32 %r752, %r742, %r7421; + or.b32 %r753, %r742, %r7422; + or.b32 %r754, %r742, %r7423; + or.b32 %r755, %r742, %r7424; + or.b32 %r756, %r742, %r7425; + or.b32 %r757, %r742, %r7426; + or.b32 %r758, %r742, %r7427; + or.b32 %r759, %r742, %r7428; + or.b32 %r7429, %r742, %r12; + or.b32 %r7430, %r742, %r13; + or.b32 %r7431, %r742, %r14; + or.b32 %r7432, %r742, %r15; +$L__tmp10: + .loc 1 583 37 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:37 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r7433, %r7429, 12; + shl.b32 %r7434, %r7430, 12; + shl.b32 %r7435, %r7431, 12; + shl.b32 %r7436, %r7432, 12; + .loc 1 584 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r7437, %r7429, 7; + shl.b32 %r7438, %r7430, 7; + shl.b32 %r7439, %r7431, 7; + shl.b32 %r7440, %r7432, 7; + .loc 1 590 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:590:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r760, %r7368, 1; + .loc 1 590 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:590:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + min.s32 %r7441, %r760, %r7413; +$L__tmp11: + .loc 1 658 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:658:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + fence.proxy.async.shared::cta; + // end inline asm + cvt.s64.s32 %rd78, %r7400; + cvt.s64.s32 %rd79, %r7401; + cvt.s64.s32 %rd80, %r7402; + cvt.s64.s32 %rd81, %r7403; + cvt.s64.s32 %rd82, %r7404; + cvt.s64.s32 %rd83, %r7405; + cvt.s64.s32 %rd84, %r7406; + cvt.s64.s32 %rd85, %r7407; + setp.gt.s32 %p514, %r693, 0; + setp.lt.s32 %p515, %r7396, %r2338; + setp.lt.s32 %p516, %r7397, %r2338; + setp.lt.s32 %p517, %r7398, %r2338; + setp.lt.s32 %p518, %r7399, %r2338; + shl.b32 %r7442, %r7385, 10; + or.b32 %r761, %r7388, %r7442; + add.s32 %r10792, %r7390, %r761; + selp.b32 %r7443, 16, 0, %p515; + selp.b32 %r7525, %r7443, 0, %p514; + add.s32 %r10794, %r10792, 2048; + selp.b32 %r7444, 16, 0, %p516; + selp.b32 %r7527, %r7444, 0, %p514; + add.s32 %r10796, %r10792, 4096; + selp.b32 %r7445, 16, 0, %p517; + selp.b32 %r7529, %r7445, 0, %p514; + add.s32 %r10798, %r10792, 6144; + selp.b32 %r7446, 16, 0, %p518; + selp.b32 %r7531, %r7446, 0, %p514; + setp.lt.s32 %p519, %r770, %r2338; + setp.lt.s32 %p520, %r771, %r2338; + setp.lt.s32 %p521, %r772, %r2338; + setp.lt.s32 %p522, %r773, %r2338; + setp.lt.s32 %p523, %r774, %r2338; + setp.lt.s32 %p524, %r775, %r2338; + setp.lt.s32 %p525, %r776, %r2338; + setp.lt.s32 %p526, %r777, %r2338; + setp.lt.s32 %p527, %r778, %r2338; + setp.lt.s32 %p528, %r779, %r2338; + setp.lt.s32 %p529, %r780, %r2338; + setp.lt.s32 %p530, %r781, %r2338; + setp.lt.s32 %p531, %r782, %r2338; + setp.lt.s32 %p532, %r783, %r2338; + setp.lt.s32 %p533, %r784, %r2338; + setp.lt.s32 %p534, %r785, %r2338; + cvt.s64.s32 %rd86, %r770; + cvt.s64.s32 %rd87, %r771; + cvt.s64.s32 %rd88, %r772; + cvt.s64.s32 %rd89, %r773; + cvt.s64.s32 %rd90, %r774; + cvt.s64.s32 %rd91, %r775; + cvt.s64.s32 %rd92, %r776; + cvt.s64.s32 %rd93, %r777; + cvt.s64.s32 %rd94, %r778; + cvt.s64.s32 %rd95, %r779; + cvt.s64.s32 %rd96, %r780; + cvt.s64.s32 %rd97, %r781; + cvt.s64.s32 %rd98, %r782; + cvt.s64.s32 %rd99, %r783; + cvt.s64.s32 %rd100, %r784; + cvt.s64.s32 %rd101, %r785; + and.b32 %r786, %r9, 252; + shl.b32 %r787, %r691, 3; + add.s32 %r7447, %r7390, %r787; + add.s32 %r10800, %r7447, 98304; + selp.b32 %r7448, 4, 0, %p519; + selp.b32 %r7533, %r7448, 0, %p514; + add.s32 %r10802, %r7447, 98308; + selp.b32 %r7449, 4, 0, %p520; + selp.b32 %r7535, %r7449, 0, %p514; + add.s32 %r10804, %r7447, 98336; + selp.b32 %r7450, 4, 0, %p521; + selp.b32 %r7537, %r7450, 0, %p514; + add.s32 %r10806, %r7447, 98340; + selp.b32 %r7451, 4, 0, %p522; + selp.b32 %r7539, %r7451, 0, %p514; + add.s32 %r10808, %r7447, 98368; + selp.b32 %r7452, 4, 0, %p523; + selp.b32 %r7541, %r7452, 0, %p514; + add.s32 %r10810, %r7447, 98372; + selp.b32 %r7453, 4, 0, %p524; + selp.b32 %r7543, %r7453, 0, %p514; + add.s32 %r10812, %r7447, 98400; + selp.b32 %r7454, 4, 0, %p525; + selp.b32 %r7545, %r7454, 0, %p514; + add.s32 %r10814, %r7447, 98404; + selp.b32 %r7455, 4, 0, %p526; + selp.b32 %r7547, %r7455, 0, %p514; + add.s32 %r10816, %r7447, 98432; + selp.b32 %r7456, 4, 0, %p527; + selp.b32 %r7549, %r7456, 0, %p514; + add.s32 %r10818, %r7447, 98436; + selp.b32 %r7457, 4, 0, %p528; + selp.b32 %r7551, %r7457, 0, %p514; + add.s32 %r10820, %r7447, 98464; + selp.b32 %r7458, 4, 0, %p529; + selp.b32 %r7553, %r7458, 0, %p514; + add.s32 %r10822, %r7447, 98468; + selp.b32 %r7459, 4, 0, %p530; + selp.b32 %r7555, %r7459, 0, %p514; + add.s32 %r10824, %r7447, 98496; + selp.b32 %r7460, 4, 0, %p531; + selp.b32 %r7557, %r7460, 0, %p514; + add.s32 %r10826, %r7447, 98500; + selp.b32 %r7461, 4, 0, %p532; + selp.b32 %r7559, %r7461, 0, %p514; + add.s32 %r10828, %r7447, 98528; + selp.b32 %r7462, 4, 0, %p533; + selp.b32 %r7561, %r7462, 0, %p514; + add.s32 %r10830, %r7447, 98532; + selp.b32 %r7463, 4, 0, %p534; + selp.b32 %r7563, %r7463, 0, %p514; + add.s32 %r10832, %r10792, 49152; + add.s32 %r10834, %r10792, 51200; + add.s32 %r10836, %r10792, 53248; + add.s32 %r10838, %r10792, 55296; + add.s32 %r10840, %r7447, 98816; + add.s32 %r10842, %r7447, 98820; + add.s32 %r10844, %r7447, 98848; + add.s32 %r10846, %r7447, 98852; + add.s32 %r10848, %r7447, 98880; + add.s32 %r10850, %r7447, 98884; + add.s32 %r10852, %r7447, 98912; + add.s32 %r10854, %r7447, 98916; + add.s32 %r10856, %r7447, 98944; + add.s32 %r10858, %r7447, 98948; + add.s32 %r10860, %r7447, 98976; + add.s32 %r10862, %r7447, 98980; + add.s32 %r10864, %r7447, 99008; + add.s32 %r10866, %r7447, 99012; + add.s32 %r10868, %r7447, 99040; + add.s32 %r10870, %r7447, 99044; + setp.gt.s32 %p535, %r7414, 1; + or.b32 %r840, %r770, 64; + or.b32 %r841, %r771, 64; + or.b32 %r842, %r772, 64; + or.b32 %r843, %r773, 64; + or.b32 %r844, %r774, 64; + or.b32 %r845, %r775, 64; + or.b32 %r846, %r776, 64; + or.b32 %r847, %r777, 64; + or.b32 %r848, %r778, 64; + or.b32 %r849, %r779, 64; + or.b32 %r850, %r780, 64; + or.b32 %r851, %r781, 64; + or.b32 %r852, %r782, 64; + or.b32 %r853, %r783, 64; + or.b32 %r854, %r784, 64; + or.b32 %r855, %r785, 64; + or.b32 %r856, %r7396, 64; + or.b32 %r857, %r7397, 64; + or.b32 %r858, %r7398, 64; + or.b32 %r859, %r7399, 64; + setp.lt.s32 %p536, %r856, %r2338; + setp.lt.s32 %p537, %r857, %r2338; + setp.lt.s32 %p538, %r858, %r2338; + setp.lt.s32 %p539, %r859, %r2338; + add.s32 %r10872, %r10792, 16384; + selp.b32 %r7464, 16, 0, %p536; + selp.b32 %r7605, %r7464, 0, %p535; + add.s32 %r10874, %r10792, 18432; + selp.b32 %r7465, 16, 0, %p537; + selp.b32 %r7607, %r7465, 0, %p535; + add.s32 %r10876, %r10792, 20480; + selp.b32 %r7466, 16, 0, %p538; + selp.b32 %r7609, %r7466, 0, %p535; + add.s32 %r10878, %r10792, 22528; + selp.b32 %r7467, 16, 0, %p539; + selp.b32 %r7611, %r7467, 0, %p535; + setp.lt.s32 %p540, %r840, %r2338; + setp.lt.s32 %p541, %r841, %r2338; + setp.lt.s32 %p542, %r842, %r2338; + setp.lt.s32 %p543, %r843, %r2338; + setp.lt.s32 %p544, %r844, %r2338; + setp.lt.s32 %p545, %r845, %r2338; + setp.lt.s32 %p546, %r846, %r2338; + setp.lt.s32 %p547, %r847, %r2338; + setp.lt.s32 %p548, %r848, %r2338; + setp.lt.s32 %p549, %r849, %r2338; + setp.lt.s32 %p550, %r850, %r2338; + setp.lt.s32 %p551, %r851, %r2338; + setp.lt.s32 %p552, %r852, %r2338; + setp.lt.s32 %p553, %r853, %r2338; + setp.lt.s32 %p554, %r854, %r2338; + setp.lt.s32 %p555, %r855, %r2338; + add.s32 %r10880, %r7447, 98560; + selp.b32 %r7468, 4, 0, %p540; + selp.b32 %r7613, %r7468, 0, %p535; + add.s32 %r10882, %r7447, 98564; + selp.b32 %r7469, 4, 0, %p541; + selp.b32 %r7615, %r7469, 0, %p535; + add.s32 %r10884, %r7447, 98592; + selp.b32 %r7470, 4, 0, %p542; + selp.b32 %r7617, %r7470, 0, %p535; + add.s32 %r10886, %r7447, 98596; + selp.b32 %r7471, 4, 0, %p543; + selp.b32 %r7619, %r7471, 0, %p535; + add.s32 %r10888, %r7447, 98624; + selp.b32 %r7472, 4, 0, %p544; + selp.b32 %r7621, %r7472, 0, %p535; + add.s32 %r10890, %r7447, 98628; + selp.b32 %r7473, 4, 0, %p545; + selp.b32 %r7623, %r7473, 0, %p535; + add.s32 %r10892, %r7447, 98656; + selp.b32 %r7474, 4, 0, %p546; + selp.b32 %r7625, %r7474, 0, %p535; + add.s32 %r10894, %r7447, 98660; + selp.b32 %r7475, 4, 0, %p547; + selp.b32 %r7627, %r7475, 0, %p535; + add.s32 %r10896, %r7447, 98688; + selp.b32 %r7476, 4, 0, %p548; + selp.b32 %r7629, %r7476, 0, %p535; + add.s32 %r10898, %r7447, 98692; + selp.b32 %r7477, 4, 0, %p549; + selp.b32 %r7631, %r7477, 0, %p535; + add.s32 %r10900, %r7447, 98720; + selp.b32 %r7478, 4, 0, %p550; + selp.b32 %r7633, %r7478, 0, %p535; + add.s32 %r10902, %r7447, 98724; + selp.b32 %r7479, 4, 0, %p551; + selp.b32 %r7635, %r7479, 0, %p535; + add.s32 %r10904, %r7447, 98752; + selp.b32 %r7480, 4, 0, %p552; + selp.b32 %r7637, %r7480, 0, %p535; + add.s32 %r10906, %r7447, 98756; + selp.b32 %r7481, 4, 0, %p553; + selp.b32 %r7639, %r7481, 0, %p535; + add.s32 %r10908, %r7447, 98784; + selp.b32 %r7482, 4, 0, %p554; + selp.b32 %r7641, %r7482, 0, %p535; + add.s32 %r10910, %r7447, 98788; + selp.b32 %r7483, 4, 0, %p555; + selp.b32 %r7643, %r7483, 0, %p535; + add.s32 %r10912, %r10792, 65536; + add.s32 %r10914, %r10792, 67584; + add.s32 %r10916, %r10792, 69632; + add.s32 %r10918, %r10792, 71680; + add.s32 %r10920, %r7447, 99072; + add.s32 %r10922, %r7447, 99076; + add.s32 %r10924, %r7447, 99104; + add.s32 %r10926, %r7447, 99108; + add.s32 %r10928, %r7447, 99136; + add.s32 %r10930, %r7447, 99140; + add.s32 %r10932, %r7447, 99168; + add.s32 %r10934, %r7447, 99172; + add.s32 %r10936, %r7447, 99200; + add.s32 %r10938, %r7447, 99204; + add.s32 %r10940, %r7447, 99232; + add.s32 %r10942, %r7447, 99236; + add.s32 %r10944, %r7447, 99264; + add.s32 %r10946, %r7447, 99268; + add.s32 %r10948, %r7447, 99296; + add.s32 %r10950, %r7447, 99300; + add.s32 %r920, %r7414, -2; + add.s32 %r921, %r7414, -1; + cvt.s64.s32 %rd102, %r7433; + cvt.s64.s32 %rd103, %r7434; + cvt.s64.s32 %rd104, %r7435; + cvt.s64.s32 %rd105, %r7436; + cvt.s64.s32 %rd106, %r7437; + cvt.s64.s32 %rd107, %r7438; + cvt.s64.s32 %rd108, %r7439; + cvt.s64.s32 %rd109, %r7440; + setp.gt.s32 %p556, %r760, 0; + setp.lt.s32 %p557, %r7429, %r2338; + setp.lt.s32 %p558, %r7430, %r2338; + setp.lt.s32 %p559, %r7431, %r2338; + setp.lt.s32 %p560, %r7432, %r2338; + selp.b32 %r7484, 16, 0, %p557; + selp.b32 %r10793, %r7484, 0, %p556; + selp.b32 %r7485, 16, 0, %p558; + selp.b32 %r10795, %r7485, 0, %p556; + selp.b32 %r7486, 16, 0, %p559; + selp.b32 %r10797, %r7486, 0, %p556; + selp.b32 %r7487, 16, 0, %p560; + selp.b32 %r10799, %r7487, 0, %p556; + setp.lt.s32 %p561, %r744, %r2338; + setp.lt.s32 %p562, %r745, %r2338; + setp.lt.s32 %p563, %r746, %r2338; + setp.lt.s32 %p564, %r747, %r2338; + setp.lt.s32 %p565, %r748, %r2338; + setp.lt.s32 %p566, %r749, %r2338; + setp.lt.s32 %p567, %r750, %r2338; + setp.lt.s32 %p568, %r751, %r2338; + setp.lt.s32 %p569, %r752, %r2338; + setp.lt.s32 %p570, %r753, %r2338; + setp.lt.s32 %p571, %r754, %r2338; + setp.lt.s32 %p572, %r755, %r2338; + setp.lt.s32 %p573, %r756, %r2338; + setp.lt.s32 %p574, %r757, %r2338; + setp.lt.s32 %p575, %r758, %r2338; + setp.lt.s32 %p576, %r759, %r2338; + cvt.s64.s32 %rd110, %r744; + cvt.s64.s32 %rd111, %r748; + cvt.s64.s32 %rd112, %r749; + cvt.s64.s32 %rd113, %r750; + cvt.s64.s32 %rd114, %r751; + cvt.s64.s32 %rd115, %r752; + cvt.s64.s32 %rd116, %r753; + cvt.s64.s32 %rd117, %r754; + cvt.s64.s32 %rd118, %r755; + cvt.s64.s32 %rd119, %r756; + cvt.s64.s32 %rd120, %r757; + cvt.s64.s32 %rd121, %r758; + cvt.s64.s32 %rd122, %r759; + selp.b32 %r7488, 4, 0, %p561; + selp.b32 %r10801, %r7488, 0, %p556; + selp.b32 %r7489, 4, 0, %p562; + selp.b32 %r10803, %r7489, 0, %p556; + selp.b32 %r7490, 4, 0, %p563; + selp.b32 %r10805, %r7490, 0, %p556; + selp.b32 %r7491, 4, 0, %p564; + selp.b32 %r10807, %r7491, 0, %p556; + selp.b32 %r7492, 4, 0, %p565; + selp.b32 %r10809, %r7492, 0, %p556; + selp.b32 %r7493, 4, 0, %p566; + selp.b32 %r10811, %r7493, 0, %p556; + selp.b32 %r7494, 4, 0, %p567; + selp.b32 %r10813, %r7494, 0, %p556; + selp.b32 %r7495, 4, 0, %p568; + selp.b32 %r10815, %r7495, 0, %p556; + selp.b32 %r7496, 4, 0, %p569; + selp.b32 %r10817, %r7496, 0, %p556; + selp.b32 %r7497, 4, 0, %p570; + selp.b32 %r10819, %r7497, 0, %p556; + selp.b32 %r7498, 4, 0, %p571; + selp.b32 %r10821, %r7498, 0, %p556; + selp.b32 %r7499, 4, 0, %p572; + selp.b32 %r10823, %r7499, 0, %p556; + selp.b32 %r7500, 4, 0, %p573; + selp.b32 %r10825, %r7500, 0, %p556; + selp.b32 %r7501, 4, 0, %p574; + selp.b32 %r10827, %r7501, 0, %p556; + selp.b32 %r7502, 4, 0, %p575; + selp.b32 %r10829, %r7502, 0, %p556; + selp.b32 %r7503, 4, 0, %p576; + selp.b32 %r10831, %r7503, 0, %p556; + setp.gt.s32 %p577, %r7441, 1; + or.b32 %r942, %r744, 64; + or.b32 %r943, %r745, 64; + or.b32 %r944, %r746, 64; + or.b32 %r945, %r747, 64; + or.b32 %r946, %r748, 64; + or.b32 %r947, %r749, 64; + or.b32 %r948, %r750, 64; + or.b32 %r949, %r751, 64; + or.b32 %r950, %r752, 64; + or.b32 %r951, %r753, 64; + or.b32 %r952, %r754, 64; + or.b32 %r953, %r755, 64; + or.b32 %r954, %r756, 64; + or.b32 %r955, %r757, 64; + or.b32 %r956, %r758, 64; + or.b32 %r957, %r759, 64; + or.b32 %r958, %r7429, 64; + or.b32 %r959, %r7430, 64; + or.b32 %r960, %r7431, 64; + or.b32 %r961, %r7432, 64; + setp.lt.s32 %p578, %r958, %r2338; + setp.lt.s32 %p579, %r959, %r2338; + setp.lt.s32 %p580, %r960, %r2338; + setp.lt.s32 %p581, %r961, %r2338; + selp.b32 %r7504, 16, 0, %p578; + selp.b32 %r10873, %r7504, 0, %p577; + selp.b32 %r7505, 16, 0, %p579; + selp.b32 %r10875, %r7505, 0, %p577; + selp.b32 %r7506, 16, 0, %p580; + selp.b32 %r10877, %r7506, 0, %p577; + selp.b32 %r7507, 16, 0, %p581; + selp.b32 %r10879, %r7507, 0, %p577; + setp.lt.s32 %p582, %r942, %r2338; + setp.lt.s32 %p583, %r943, %r2338; + setp.lt.s32 %p584, %r944, %r2338; + setp.lt.s32 %p585, %r945, %r2338; + setp.lt.s32 %p586, %r946, %r2338; + setp.lt.s32 %p587, %r947, %r2338; + setp.lt.s32 %p588, %r948, %r2338; + setp.lt.s32 %p589, %r949, %r2338; + setp.lt.s32 %p590, %r950, %r2338; + setp.lt.s32 %p591, %r951, %r2338; + setp.lt.s32 %p592, %r952, %r2338; + setp.lt.s32 %p593, %r953, %r2338; + setp.lt.s32 %p594, %r954, %r2338; + setp.lt.s32 %p595, %r955, %r2338; + setp.lt.s32 %p596, %r956, %r2338; + setp.lt.s32 %p597, %r957, %r2338; + selp.b32 %r7508, 4, 0, %p582; + selp.b32 %r10881, %r7508, 0, %p577; + selp.b32 %r7509, 4, 0, %p583; + selp.b32 %r10883, %r7509, 0, %p577; + selp.b32 %r7510, 4, 0, %p584; + selp.b32 %r10885, %r7510, 0, %p577; + selp.b32 %r7511, 4, 0, %p585; + selp.b32 %r10887, %r7511, 0, %p577; + selp.b32 %r7512, 4, 0, %p586; + selp.b32 %r10889, %r7512, 0, %p577; + selp.b32 %r7513, 4, 0, %p587; + selp.b32 %r10891, %r7513, 0, %p577; + selp.b32 %r7514, 4, 0, %p588; + selp.b32 %r10893, %r7514, 0, %p577; + selp.b32 %r7515, 4, 0, %p589; + selp.b32 %r10895, %r7515, 0, %p577; + selp.b32 %r7516, 4, 0, %p590; + selp.b32 %r10897, %r7516, 0, %p577; + selp.b32 %r7517, 4, 0, %p591; + selp.b32 %r10899, %r7517, 0, %p577; + selp.b32 %r7518, 4, 0, %p592; + selp.b32 %r10901, %r7518, 0, %p577; + selp.b32 %r7519, 4, 0, %p593; + selp.b32 %r10903, %r7519, 0, %p577; + selp.b32 %r7520, 4, 0, %p594; + selp.b32 %r10905, %r7520, 0, %p577; + selp.b32 %r7521, 4, 0, %p595; + selp.b32 %r10907, %r7521, 0, %p577; + selp.b32 %r7522, 4, 0, %p596; + selp.b32 %r10909, %r7522, 0, %p577; + selp.b32 %r7523, 4, 0, %p597; + selp.b32 %r10911, %r7523, 0, %p577; + add.s32 %r982, %r7441, -2; + add.s32 %r983, %r7441, -1; +$L__tmp12: + .loc 1 262 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:262:30 + max.s32 %r984, %r7414, 1; + max.s32 %r985, %r7441, 1; + mul.wide.u32 %rd123, %r7, 4; + mov.b32 %r14726, 0f00000000; + mov.b64 %rd1120, 0; + shl.b64 %rd635, %rd78, 1; + shl.b64 %rd637, %rd79, 1; + shl.b64 %rd639, %rd80, 1; + shl.b64 %rd641, %rd81, 1; + shl.b64 %rd644, %rd82, 1; + shl.b64 %rd646, %rd83, 1; + shl.b64 %rd648, %rd84, 1; + shl.b64 %rd650, %rd85, 1; + shl.b64 %rd652, %rd86, 2; + shl.b64 %rd653, %rd87, 2; + shl.b64 %rd654, %rd88, 2; + shl.b64 %rd655, %rd89, 2; + shl.b64 %rd656, %rd90, 2; + shl.b64 %rd657, %rd91, 2; + shl.b64 %rd658, %rd92, 2; + shl.b64 %rd659, %rd93, 2; + shl.b64 %rd660, %rd94, 2; + shl.b64 %rd661, %rd95, 2; + shl.b64 %rd662, %rd96, 2; + shl.b64 %rd663, %rd97, 2; + shl.b64 %rd664, %rd98, 2; + shl.b64 %rd665, %rd99, 2; + shl.b64 %rd666, %rd100, 2; + shl.b64 %rd667, %rd101, 2; + shl.b64 %rd890, %rd102, 1; + shl.b64 %rd892, %rd103, 1; + shl.b64 %rd894, %rd104, 1; + shl.b64 %rd896, %rd105, 1; + shl.b64 %rd899, %rd106, 1; + shl.b64 %rd901, %rd107, 1; + shl.b64 %rd903, %rd108, 1; + shl.b64 %rd905, %rd109, 1; + shl.b64 %rd907, %rd110, 2; + shl.b64 %rd913, %rd111, 2; + shl.b64 %rd914, %rd112, 2; + shl.b64 %rd915, %rd113, 2; + shl.b64 %rd916, %rd114, 2; + shl.b64 %rd917, %rd115, 2; + shl.b64 %rd918, %rd116, 2; + shl.b64 %rd919, %rd117, 2; + shl.b64 %rd920, %rd118, 2; + shl.b64 %rd921, %rd119, 2; + shl.b64 %rd922, %rd120, 2; + shl.b64 %rd923, %rd121, 2; + shl.b64 %rd924, %rd122, 2; + mov.b32 %r14727, %r14726; + mov.b32 %r14728, %r14726; + mov.b32 %r14729, %r14726; + mov.b32 %r14730, %r14726; + mov.b32 %r14731, %r14726; + mov.b32 %r14732, %r14726; + mov.b32 %r14733, %r14726; + mov.b32 %r14734, %r14726; + mov.b32 %r14735, %r14726; + mov.b32 %r14736, %r14726; + mov.b32 %r14737, %r14726; + mov.b32 %r14738, %r14726; + mov.b32 %r14739, %r14726; + mov.b32 %r14740, %r14726; + mov.b32 %r14741, %r14726; + mov.b32 %r14742, %r14726; + mov.b32 %r14743, %r14726; + mov.b32 %r14744, %r14726; + mov.b32 %r14745, %r14726; + mov.b32 %r14746, %r14726; + mov.b32 %r14747, %r14726; + mov.b32 %r14748, %r14726; + mov.b32 %r14749, %r14726; + mov.b32 %r14750, %r14726; + mov.b32 %r14751, %r14726; + mov.b32 %r14752, %r14726; + mov.b32 %r14753, %r14726; + mov.b32 %r14754, %r14726; + mov.b32 %r14755, %r14726; + mov.b32 %r14756, %r14726; + mov.b32 %r14757, %r14726; + mov.b32 %r14758, %r14726; + mov.b32 %r14759, %r14726; + mov.b32 %r14760, %r14726; + mov.b32 %r14761, %r14726; + mov.b32 %r14762, %r14726; + mov.b32 %r14763, %r14726; + mov.b32 %r14764, %r14726; + mov.b32 %r14765, %r14726; + mov.b32 %r14766, %r14726; + mov.b32 %r14767, %r14726; + mov.b32 %r14768, %r14726; + mov.b32 %r14769, %r14726; + mov.b32 %r14770, %r14726; + mov.b32 %r14771, %r14726; + mov.b32 %r14772, %r14726; + mov.b32 %r14773, %r14726; + mov.b32 %r14774, %r14726; + mov.b32 %r14775, %r14726; + mov.b32 %r14776, %r14726; + mov.b32 %r14777, %r14726; + mov.b32 %r14778, %r14726; + mov.b32 %r14779, %r14726; + mov.b32 %r14780, %r14726; + mov.b32 %r14781, %r14726; + mov.b32 %r14782, %r14726; + mov.b32 %r14783, %r14726; + mov.b32 %r14784, %r14726; + mov.b32 %r14785, %r14726; + mov.b32 %r14786, %r14726; + mov.b32 %r14787, %r14726; + mov.b32 %r14788, %r14726; + mov.b32 %r14789, %r14726; + mov.b32 %r14662, %r14726; + mov.b32 %r14663, %r14726; + mov.b32 %r14664, %r14726; + mov.b32 %r14665, %r14726; + mov.b32 %r14666, %r14726; + mov.b32 %r14667, %r14726; + mov.b32 %r14668, %r14726; + mov.b32 %r14669, %r14726; + mov.b32 %r14670, %r14726; + mov.b32 %r14671, %r14726; + mov.b32 %r14672, %r14726; + mov.b32 %r14673, %r14726; + mov.b32 %r14674, %r14726; + mov.b32 %r14675, %r14726; + mov.b32 %r14676, %r14726; + mov.b32 %r14677, %r14726; + mov.b32 %r14678, %r14726; + mov.b32 %r14679, %r14726; + mov.b32 %r14680, %r14726; + mov.b32 %r14681, %r14726; + mov.b32 %r14682, %r14726; + mov.b32 %r14683, %r14726; + mov.b32 %r14684, %r14726; + mov.b32 %r14685, %r14726; + mov.b32 %r14686, %r14726; + mov.b32 %r14687, %r14726; + mov.b32 %r14688, %r14726; + mov.b32 %r14689, %r14726; + mov.b32 %r14690, %r14726; + mov.b32 %r14691, %r14726; + mov.b32 %r14692, %r14726; + mov.b32 %r14693, %r14726; + mov.b32 %r14694, %r14726; + mov.b32 %r14695, %r14726; + mov.b32 %r14696, %r14726; + mov.b32 %r14697, %r14726; + mov.b32 %r14698, %r14726; + mov.b32 %r14699, %r14726; + mov.b32 %r14700, %r14726; + mov.b32 %r14701, %r14726; + mov.b32 %r14702, %r14726; + mov.b32 %r14703, %r14726; + mov.b32 %r14704, %r14726; + mov.b32 %r14705, %r14726; + mov.b32 %r14706, %r14726; + mov.b32 %r14707, %r14726; + mov.b32 %r14708, %r14726; + mov.b32 %r14709, %r14726; + mov.b32 %r14710, %r14726; + mov.b32 %r14711, %r14726; + mov.b32 %r14712, %r14726; + mov.b32 %r14713, %r14726; + mov.b32 %r14714, %r14726; + mov.b32 %r14715, %r14726; + mov.b32 %r14716, %r14726; + mov.b32 %r14717, %r14726; + mov.b32 %r14718, %r14726; + mov.b32 %r14719, %r14726; + mov.b32 %r14720, %r14726; + mov.b32 %r14721, %r14726; + mov.b32 %r14722, %r14726; + mov.b32 %r14723, %r14726; + mov.b32 %r14724, %r14726; + mov.b32 %r14725, %r14726; + bra.uni $L__BB0_9; +$L__BB0_15: // %._crit_edge1874 + // in Loop: Header=BB0_9 Depth=1 +$L__tmp13: + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + // wait for regs: %r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725,%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789 + wgmma.wait_group.sync.aligned 0; + // end inline asm + cp.async.wait_group 0; + bar.sync 0; +$L__tmp14: + .loc 1 262 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:262:30 + add.s64 %rd1120, %rd1120, 1; + setp.ne.b64 %p1252, %rd1120, 4; + @%p1252 bra $L__BB0_9; + bra.uni $L__BB0_16; +$L__BB0_9: // =>This Loop Header: Depth=1 + // Child Loop BB0_11 Depth 2 + // Child Loop BB0_14 Depth 2 + .loc 1 0 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:30 + setp.eq.b32 %p1068, %r786, 0; + setp.lt.s32 %p662, %r693, 1; + .loc 1 263 51 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:263:51 + add.s64 %rd633, %rd1120, %rd123; + .loc 1 266 44 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:266:44 + cvt.u32.u64 %r7684, %rd633; + shl.b32 %r7685, %r7684, 7; + add.s32 %r7686, %r7685, %r687; + .loc 1 267 46 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:267:46 + mad.lo.s32 %r7687, %r3, %r7684, %r688; + .loc 1 269 50 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:269:50 + add.s32 %r7688, %r689, %r7684; + mul.lo.s32 %r7689, %r7688, %r2338; + .loc 1 271 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:271:21 + mad.wide.s32 %rd125, %r7686, 2, %rd178; + .loc 1 272 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:272:23 + mad.wide.s32 %rd126, %r7687, 2, %rd181; + .loc 1 275 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:275:25 + mul.wide.s32 %rd634, %r7689, 4; + add.s64 %rd127, %rd179, %rd634; + .loc 1 276 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:276:29 + add.s64 %rd128, %rd180, %rd634; +$L__tmp15: + .loc 1 583 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd636, %rd125, %rd635; + add.s64 %rd638, %rd125, %rd637; + add.s64 %rd640, %rd125, %rd639; + add.s64 %rd642, %rd125, %rd641; + .loc 1 583 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b64 %rd643, %rd75, 1; + add.s64 %rd553, %rd636, %rd643; + add.s64 %rd554, %rd638, %rd643; + add.s64 %rd555, %rd640, %rd643; + add.s64 %rd556, %rd642, %rd643; + .loc 1 584 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd645, %rd126, %rd644; + add.s64 %rd647, %rd126, %rd646; + add.s64 %rd649, %rd126, %rd648; + add.s64 %rd651, %rd126, %rd650; + .loc 1 584 51 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:51 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd573, %rd645, %rd643; + add.s64 %rd574, %rd647, %rd643; + add.s64 %rd575, %rd649, %rd643; + add.s64 %rd576, %rd651, %rd643; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + cp.async.cg.shared.global [ %r10792 + 0 ], [ %rd553 + 0 ], 0x10, %r7525; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10794 + 0 ], [ %rd554 + 0 ], 0x10, %r7527; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10796 + 0 ], [ %rd555 + 0 ], 0x10, %r7529; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10798 + 0 ], [ %rd556 + 0 ], 0x10, %r7531; + // end inline asm + cp.async.commit_group; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd557, %rd127, %rd652; + add.s64 %rd558, %rd127, %rd653; + add.s64 %rd559, %rd127, %rd654; + add.s64 %rd560, %rd127, %rd655; + add.s64 %rd561, %rd127, %rd656; + add.s64 %rd562, %rd127, %rd657; + add.s64 %rd563, %rd127, %rd658; + add.s64 %rd564, %rd127, %rd659; + add.s64 %rd565, %rd127, %rd660; + add.s64 %rd566, %rd127, %rd661; + add.s64 %rd567, %rd127, %rd662; + add.s64 %rd568, %rd127, %rd663; + add.s64 %rd569, %rd127, %rd664; + add.s64 %rd570, %rd127, %rd665; + add.s64 %rd571, %rd127, %rd666; + add.s64 %rd572, %rd127, %rd667; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10800 + 0 ], [ %rd557 + 0 ], 0x4, %r7533; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10802 + 0 ], [ %rd558 + 0 ], 0x4, %r7535; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10804 + 0 ], [ %rd559 + 0 ], 0x4, %r7537; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10806 + 0 ], [ %rd560 + 0 ], 0x4, %r7539; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10808 + 0 ], [ %rd561 + 0 ], 0x4, %r7541; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10810 + 0 ], [ %rd562 + 0 ], 0x4, %r7543; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10812 + 0 ], [ %rd563 + 0 ], 0x4, %r7545; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10814 + 0 ], [ %rd564 + 0 ], 0x4, %r7547; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10816 + 0 ], [ %rd565 + 0 ], 0x4, %r7549; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10818 + 0 ], [ %rd566 + 0 ], 0x4, %r7551; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10820 + 0 ], [ %rd567 + 0 ], 0x4, %r7553; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10822 + 0 ], [ %rd568 + 0 ], 0x4, %r7555; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10824 + 0 ], [ %rd569 + 0 ], 0x4, %r7557; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10826 + 0 ], [ %rd570 + 0 ], 0x4, %r7559; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10828 + 0 ], [ %rd571 + 0 ], 0x4, %r7561; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10830 + 0 ], [ %rd572 + 0 ], 0x4, %r7563; + // end inline asm + cp.async.commit_group; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + cp.async.cg.shared.global [ %r10832 + 0 ], [ %rd573 + 0 ], 0x10, %r7525; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10834 + 0 ], [ %rd574 + 0 ], 0x10, %r7527; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10836 + 0 ], [ %rd575 + 0 ], 0x10, %r7529; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10838 + 0 ], [ %rd576 + 0 ], 0x10, %r7531; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd577, %rd128, %rd652; + add.s64 %rd578, %rd128, %rd653; + add.s64 %rd579, %rd128, %rd654; + add.s64 %rd580, %rd128, %rd655; + add.s64 %rd581, %rd128, %rd656; + add.s64 %rd582, %rd128, %rd657; + add.s64 %rd583, %rd128, %rd658; + add.s64 %rd584, %rd128, %rd659; + add.s64 %rd585, %rd128, %rd660; + add.s64 %rd586, %rd128, %rd661; + add.s64 %rd587, %rd128, %rd662; + add.s64 %rd588, %rd128, %rd663; + add.s64 %rd589, %rd128, %rd664; + add.s64 %rd590, %rd128, %rd665; + add.s64 %rd591, %rd128, %rd666; + add.s64 %rd592, %rd128, %rd667; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10840 + 0 ], [ %rd577 + 0 ], 0x4, %r7533; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10842 + 0 ], [ %rd578 + 0 ], 0x4, %r7535; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10844 + 0 ], [ %rd579 + 0 ], 0x4, %r7537; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10846 + 0 ], [ %rd580 + 0 ], 0x4, %r7539; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10848 + 0 ], [ %rd581 + 0 ], 0x4, %r7541; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10850 + 0 ], [ %rd582 + 0 ], 0x4, %r7543; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10852 + 0 ], [ %rd583 + 0 ], 0x4, %r7545; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10854 + 0 ], [ %rd584 + 0 ], 0x4, %r7547; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10856 + 0 ], [ %rd585 + 0 ], 0x4, %r7549; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10858 + 0 ], [ %rd586 + 0 ], 0x4, %r7551; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10860 + 0 ], [ %rd587 + 0 ], 0x4, %r7553; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10862 + 0 ], [ %rd588 + 0 ], 0x4, %r7555; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10864 + 0 ], [ %rd589 + 0 ], 0x4, %r7557; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10866 + 0 ], [ %rd590 + 0 ], 0x4, %r7559; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10868 + 0 ], [ %rd591 + 0 ], 0x4, %r7561; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10870 + 0 ], [ %rd592 + 0 ], 0x4, %r7563; + // end inline asm + cp.async.commit_group; + .loc 1 608 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd1128, %rd553, 524288; + add.s64 %rd1127, %rd554, 524288; + add.s64 %rd1126, %rd555, 524288; + add.s64 %rd1125, %rd556, 524288; + .loc 1 609 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd1124, %rd573, 16384; + add.s64 %rd1123, %rd574, 16384; + add.s64 %rd1122, %rd575, 16384; + add.s64 %rd1121, %rd576, 16384; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + bar.sync 0; + // begin inline asm + cp.async.cg.shared.global [ %r10872 + 0 ], [ %rd1128 + 0 ], 0x10, %r7605; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10874 + 0 ], [ %rd1127 + 0 ], 0x10, %r7607; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10876 + 0 ], [ %rd1126 + 0 ], 0x10, %r7609; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10878 + 0 ], [ %rd1125 + 0 ], 0x10, %r7611; + // end inline asm + cp.async.commit_group; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd597, %rd557, 256; + add.s64 %rd598, %rd558, 256; + add.s64 %rd599, %rd559, 256; + add.s64 %rd600, %rd560, 256; + add.s64 %rd601, %rd561, 256; + add.s64 %rd602, %rd562, 256; + add.s64 %rd603, %rd563, 256; + add.s64 %rd604, %rd564, 256; + add.s64 %rd605, %rd565, 256; + add.s64 %rd606, %rd566, 256; + add.s64 %rd607, %rd567, 256; + add.s64 %rd608, %rd568, 256; + add.s64 %rd609, %rd569, 256; + add.s64 %rd610, %rd570, 256; + add.s64 %rd611, %rd571, 256; + add.s64 %rd612, %rd572, 256; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10880 + 0 ], [ %rd597 + 0 ], 0x4, %r7613; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10882 + 0 ], [ %rd598 + 0 ], 0x4, %r7615; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10884 + 0 ], [ %rd599 + 0 ], 0x4, %r7617; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10886 + 0 ], [ %rd600 + 0 ], 0x4, %r7619; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10888 + 0 ], [ %rd601 + 0 ], 0x4, %r7621; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10890 + 0 ], [ %rd602 + 0 ], 0x4, %r7623; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10892 + 0 ], [ %rd603 + 0 ], 0x4, %r7625; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10894 + 0 ], [ %rd604 + 0 ], 0x4, %r7627; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10896 + 0 ], [ %rd605 + 0 ], 0x4, %r7629; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10898 + 0 ], [ %rd606 + 0 ], 0x4, %r7631; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10900 + 0 ], [ %rd607 + 0 ], 0x4, %r7633; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10902 + 0 ], [ %rd608 + 0 ], 0x4, %r7635; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10904 + 0 ], [ %rd609 + 0 ], 0x4, %r7637; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10906 + 0 ], [ %rd610 + 0 ], 0x4, %r7639; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10908 + 0 ], [ %rd611 + 0 ], 0x4, %r7641; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10910 + 0 ], [ %rd612 + 0 ], 0x4, %r7643; + // end inline asm + cp.async.commit_group; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + cp.async.cg.shared.global [ %r10912 + 0 ], [ %rd1124 + 0 ], 0x10, %r7605; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10914 + 0 ], [ %rd1123 + 0 ], 0x10, %r7607; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10916 + 0 ], [ %rd1122 + 0 ], 0x10, %r7609; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10918 + 0 ], [ %rd1121 + 0 ], 0x10, %r7611; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd617, %rd577, 256; + add.s64 %rd618, %rd578, 256; + add.s64 %rd619, %rd579, 256; + add.s64 %rd620, %rd580, 256; + add.s64 %rd621, %rd581, 256; + add.s64 %rd622, %rd582, 256; + add.s64 %rd623, %rd583, 256; + add.s64 %rd624, %rd584, 256; + add.s64 %rd625, %rd585, 256; + add.s64 %rd626, %rd586, 256; + add.s64 %rd627, %rd587, 256; + add.s64 %rd628, %rd588, 256; + add.s64 %rd629, %rd589, 256; + add.s64 %rd630, %rd590, 256; + add.s64 %rd631, %rd591, 256; + add.s64 %rd632, %rd592, 256; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10920 + 0 ], [ %rd617 + 0 ], 0x4, %r7613; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10922 + 0 ], [ %rd618 + 0 ], 0x4, %r7615; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10924 + 0 ], [ %rd619 + 0 ], 0x4, %r7617; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10926 + 0 ], [ %rd620 + 0 ], 0x4, %r7619; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10928 + 0 ], [ %rd621 + 0 ], 0x4, %r7621; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10930 + 0 ], [ %rd622 + 0 ], 0x4, %r7623; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10932 + 0 ], [ %rd623 + 0 ], 0x4, %r7625; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10934 + 0 ], [ %rd624 + 0 ], 0x4, %r7627; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10936 + 0 ], [ %rd625 + 0 ], 0x4, %r7629; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10938 + 0 ], [ %rd626 + 0 ], 0x4, %r7631; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10940 + 0 ], [ %rd627 + 0 ], 0x4, %r7633; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10942 + 0 ], [ %rd628 + 0 ], 0x4, %r7635; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10944 + 0 ], [ %rd629 + 0 ], 0x4, %r7637; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10946 + 0 ], [ %rd630 + 0 ], 0x4, %r7639; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10948 + 0 ], [ %rd631 + 0 ], 0x4, %r7641; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10950 + 0 ], [ %rd632 + 0 ], 0x4, %r7643; + // end inline asm + cp.async.commit_group; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + @%p662 bra $L__BB0_12; +// %bb.10: // %.lr.ph1700.preheader + // in Loop: Header=BB0_9 Depth=1 + .loc 1 0 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:28 + mov.b32 %r8245, 0; + mov.b32 %r14635, 1; + mov.b32 %r14634, -1; + mov.b32 %r14617, 64; + mov.b32 %r14618, %r770; + mov.b32 %r14619, %r771; + mov.b32 %r14620, %r772; + mov.b32 %r14621, %r773; + mov.b32 %r14622, %r774; + mov.b32 %r14623, %r775; + mov.b32 %r14624, %r776; + mov.b32 %r14625, %r777; + mov.b32 %r14626, %r778; + mov.b32 %r14627, %r779; + mov.b32 %r14628, %r780; + mov.b32 %r14629, %r781; + mov.b32 %r14630, %r782; + mov.b32 %r14631, %r783; + mov.b32 %r14632, %r784; + mov.b32 %r14633, %r785; + mov.b32 %r14636, %r14634; + mov.b32 %r14637, %r14635; + mov.b32 %r14638, %r855; + mov.b32 %r14639, %r854; + mov.b32 %r14640, %r853; + mov.b32 %r14641, %r852; + mov.b32 %r14642, %r851; + mov.b32 %r14643, %r850; + mov.b32 %r14644, %r849; + mov.b32 %r14645, %r848; + mov.b32 %r14646, %r847; + mov.b32 %r14647, %r846; + mov.b32 %r14648, %r845; + mov.b32 %r14649, %r844; + mov.b32 %r14650, %r843; + mov.b32 %r14651, %r842; + mov.b32 %r14652, %r841; + mov.b32 %r14653, %r840; + mov.b32 %r14654, %r856; + mov.b32 %r14655, %r857; + mov.b32 %r14656, %r858; + mov.b32 %r14657, %r859; + mov.b32 %r14658, %r856; + mov.b32 %r14659, %r857; + mov.b32 %r14660, %r858; + mov.b32 %r14661, %r859; + mov.b32 %r14790, %r8245; + mov.b32 %r14791, %r785; + mov.b32 %r14792, %r784; + mov.b32 %r14793, %r783; + mov.b32 %r14794, %r782; + mov.b32 %r14795, %r781; + mov.b32 %r14796, %r780; + mov.b32 %r14797, %r779; + mov.b32 %r14798, %r778; + mov.b32 %r14799, %r777; + mov.b32 %r14800, %r776; + mov.b32 %r14801, %r775; + mov.b32 %r14802, %r774; + mov.b32 %r14803, %r773; + mov.b32 %r14804, %r772; + mov.b32 %r14805, %r771; + mov.b32 %r14806, %r770; +$L__BB0_11: // %.lr.ph1700 + // Parent Loop BB0_9 Depth=1 + // => This Inner Loop Header: Depth=2 + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p719, %r14790, %r920; + setp.lt.s32 %p685, %r14790, %r921; + add.s32 %r9944, %r14634, 1; + setp.gt.s32 %p720, %r9944, 1; + selp.b32 %r14634, 0, %r9944, %p720; + add.s32 %r9945, %r14636, 1; + setp.gt.s32 %p721, %r9945, 2; + selp.b32 %r14636, 0, %r9945, %p721; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p722, %r14618, %r2338; + setp.lt.s32 %p723, %r14619, %r2338; + setp.lt.s32 %p724, %r14620, %r2338; + setp.lt.s32 %p725, %r14621, %r2338; + setp.lt.s32 %p726, %r14622, %r2338; + setp.lt.s32 %p727, %r14623, %r2338; + setp.lt.s32 %p728, %r14624, %r2338; + setp.lt.s32 %p729, %r14625, %r2338; + setp.lt.s32 %p730, %r14626, %r2338; + setp.lt.s32 %p731, %r14627, %r2338; + setp.lt.s32 %p732, %r14628, %r2338; + setp.lt.s32 %p733, %r14629, %r2338; + setp.lt.s32 %p734, %r14630, %r2338; + setp.lt.s32 %p735, %r14631, %r2338; + setp.lt.s32 %p736, %r14632, %r2338; + setp.lt.s32 %p737, %r14633, %r2338; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cp.async.wait_group 4; + bar.sync 0; + shl.b32 %r9946, %r14636, 14; + add.s32 %r8247, %r7390, %r9946; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r9948, %r14634, 8; + add.s32 %r9949, %r7390, 98304; + add.s32 %r9950, %r9949, %r9948; + add.s32 %r9951, %r9950, %r787; + ld.shared.v2.b32 {%r9952, %r9953}, [%r9951]; + ld.shared.v2.b32 {%r9954, %r9955}, [%r9951+32]; + ld.shared.v2.b32 {%r9956, %r9957}, [%r9951+64]; + ld.shared.v2.b32 {%r9958, %r9959}, [%r9951+96]; + ld.shared.v2.b32 {%r9960, %r9961}, [%r9951+128]; + ld.shared.v2.b32 {%r9962, %r9963}, [%r9951+160]; + ld.shared.v2.b32 {%r9964, %r9965}, [%r9951+192]; + ld.shared.v2.b32 {%r9966, %r9967}, [%r9951+224]; + .loc 1 657 26 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:657:26 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.eq.f32 %p738, %r9952, 0fFF800000; + setp.eq.f32 %p739, %r9953, 0fFF800000; + setp.eq.f32 %p740, %r9954, 0fFF800000; + setp.eq.f32 %p741, %r9955, 0fFF800000; + setp.eq.f32 %p742, %r9956, 0fFF800000; + setp.eq.f32 %p743, %r9957, 0fFF800000; + setp.eq.f32 %p744, %r9958, 0fFF800000; + setp.eq.f32 %p745, %r9959, 0fFF800000; + setp.eq.f32 %p746, %r9960, 0fFF800000; + setp.eq.f32 %p747, %r9961, 0fFF800000; + setp.eq.f32 %p748, %r9962, 0fFF800000; + setp.eq.f32 %p749, %r9963, 0fFF800000; + setp.eq.f32 %p750, %r9964, 0fFF800000; + setp.eq.f32 %p751, %r9965, 0fFF800000; + setp.eq.f32 %p752, %r9966, 0fFF800000; + setp.eq.f32 %p753, %r9967, 0fFF800000; + .loc 1 657 46 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:657:46 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r9968, 0f00000000, %r9952, %p738; + selp.f32 %r9969, 0f00000000, %r9953, %p739; + selp.f32 %r9970, 0f00000000, %r9954, %p740; + selp.f32 %r9971, 0f00000000, %r9955, %p741; + selp.f32 %r9972, 0f00000000, %r9956, %p742; + selp.f32 %r9973, 0f00000000, %r9957, %p743; + selp.f32 %r9974, 0f00000000, %r9958, %p744; + selp.f32 %r9975, 0f00000000, %r9959, %p745; + selp.f32 %r9976, 0f00000000, %r9960, %p746; + selp.f32 %r9977, 0f00000000, %r9961, %p747; + selp.f32 %r9978, 0f00000000, %r9962, %p748; + selp.f32 %r9979, 0f00000000, %r9963, %p749; + selp.f32 %r9980, 0f00000000, %r9964, %p750; + selp.f32 %r9981, 0f00000000, %r9965, %p751; + selp.f32 %r9982, 0f00000000, %r9966, %p752; + selp.f32 %r9983, 0f00000000, %r9967, %p753; + .loc 1 658 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:658:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shfl.sync.idx.b32 %r9984, %r10, 0, 31, -1; + wgmma.fence.sync.aligned; + shl.b32 %r9985, %r9984, 11; + and.b32 %r9986, %r9985, 8192; + add.s32 %r8206, %r7390, 99328; + add.s32 %r9987, %r9986, %r8206; + bfe.u32 %r9988, %r9987, 4, 14; + cvt.u64.u32 %rd754, %r9988; + or.b64 %rd668, %rd754, 4611686293372403712; + bfe.u32 %r9989, %r8247, 4, 14; + cvt.u64.u32 %rd755, %r9989; + or.b64 %rd669, %rd755, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd668, %rd669, 0, 1, 1, 0, 0; + // end inline asm + or.b32 %r9990, %r9986, 32; + add.s32 %r9991, %r9990, %r8206; + bfe.u32 %r9992, %r9991, 4, 14; + cvt.u64.u32 %rd756, %r9992; + or.b64 %rd670, %rd756, 4611686293372403712; + add.s32 %r9993, %r8247, 32; + bfe.u32 %r9994, %r9993, 4, 14; + cvt.u64.u32 %rd757, %r9994; + or.b64 %rd671, %rd757, 4611686293338849280; + mov.pred %p663, -1; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd670, %rd671, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r9995, %r9986, 64; + add.s32 %r9996, %r9995, %r8206; + bfe.u32 %r9997, %r9996, 4, 14; + cvt.u64.u32 %rd758, %r9997; + or.b64 %rd672, %rd758, 4611686293372403712; + add.s32 %r9998, %r8247, 64; + bfe.u32 %r9999, %r9998, 4, 14; + cvt.u64.u32 %rd759, %r9999; + or.b64 %rd673, %rd759, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd672, %rd673, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r10000, %r9986, 96; + add.s32 %r10001, %r10000, %r8206; + bfe.u32 %r10002, %r10001, 4, 14; + cvt.u64.u32 %rd760, %r10002; + or.b64 %rd674, %rd760, 4611686293372403712; + add.s32 %r10003, %r8247, 96; + bfe.u32 %r10004, %r10003, 4, 14; + cvt.u64.u32 %rd761, %r10004; + or.b64 %rd675, %rd761, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd674, %rd675, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r10005, %r9986, 16384; + add.s32 %r10006, %r10005, %r8206; + bfe.u32 %r10007, %r10006, 4, 14; + cvt.u64.u32 %rd762, %r10007; + or.b64 %rd676, %rd762, 4611686293372403712; + add.s32 %r10008, %r8247, 8192; + bfe.u32 %r10009, %r10008, 4, 14; + cvt.u64.u32 %rd763, %r10009; + or.b64 %rd677, %rd763, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd676, %rd677, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r10010, %r9986, 16416; + add.s32 %r10011, %r10010, %r8206; + bfe.u32 %r10012, %r10011, 4, 14; + cvt.u64.u32 %rd764, %r10012; + or.b64 %rd678, %rd764, 4611686293372403712; + add.s32 %r10013, %r8247, 8224; + bfe.u32 %r10014, %r10013, 4, 14; + cvt.u64.u32 %rd765, %r10014; + or.b64 %rd679, %rd765, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd678, %rd679, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r10015, %r9986, 16448; + add.s32 %r10016, %r10015, %r8206; + bfe.u32 %r10017, %r10016, 4, 14; + cvt.u64.u32 %rd766, %r10017; + or.b64 %rd680, %rd766, 4611686293372403712; + add.s32 %r10018, %r8247, 8256; + bfe.u32 %r10019, %r10018, 4, 14; + cvt.u64.u32 %rd767, %r10019; + or.b64 %rd681, %rd767, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd680, %rd681, %p663, 1, 1, 0, 0; + // end inline asm + or.b32 %r10020, %r9986, 16480; + add.s32 %r10021, %r10020, %r8206; + bfe.u32 %r10022, %r10021, 4, 14; + cvt.u64.u32 %rd768, %r10022; + or.b64 %rd682, %rd768, 4611686293372403712; + add.s32 %r10023, %r8247, 8288; + bfe.u32 %r10024, %r10023, 4, 14; + cvt.u64.u32 %rd769, %r10024; + or.b64 %rd683, %rd769, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821}, %rd682, %rd683, %p663, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r8209, %r8247; + mov.b32 %r8207, %r8245; + mov.b32 %r8208, %r8245; + mov.b32 %r8210, %r8245; + mov.b32 %r8211, %r8245; + // begin inline asm + // wait for regs: %r7790,%r7791,%r7792,%r7793,%r7794,%r7795,%r7796,%r7797,%r7798,%r7799,%r7800,%r7801,%r7802,%r7803,%r7804,%r7805,%r7806,%r7807,%r7808,%r7809,%r7810,%r7811,%r7812,%r7813,%r7814,%r7815,%r7816,%r7817,%r7818,%r7819,%r7820,%r7821,%r8206,%r8207,%r8208,%r8209,%r8210,%r8211 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 660 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:660:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10025, %r7790, 0f3DB504F3; + mul.f32 %r10026, %r7791, 0f3DB504F3; + mul.f32 %r10027, %r7792, 0f3DB504F3; + mul.f32 %r10028, %r7793, 0f3DB504F3; + mul.f32 %r10029, %r7794, 0f3DB504F3; + mul.f32 %r10030, %r7795, 0f3DB504F3; + mul.f32 %r10031, %r7796, 0f3DB504F3; + mul.f32 %r10032, %r7797, 0f3DB504F3; + mul.f32 %r10033, %r7798, 0f3DB504F3; + mul.f32 %r10034, %r7799, 0f3DB504F3; + mul.f32 %r10035, %r7800, 0f3DB504F3; + mul.f32 %r10036, %r7801, 0f3DB504F3; + mul.f32 %r10037, %r7802, 0f3DB504F3; + mul.f32 %r10038, %r7803, 0f3DB504F3; + mul.f32 %r10039, %r7804, 0f3DB504F3; + mul.f32 %r10040, %r7805, 0f3DB504F3; + mul.f32 %r10041, %r7806, 0f3DB504F3; + mul.f32 %r10042, %r7807, 0f3DB504F3; + mul.f32 %r10043, %r7808, 0f3DB504F3; + mul.f32 %r10044, %r7809, 0f3DB504F3; + mul.f32 %r10045, %r7810, 0f3DB504F3; + mul.f32 %r10046, %r7811, 0f3DB504F3; + mul.f32 %r10047, %r7812, 0f3DB504F3; + mul.f32 %r10048, %r7813, 0f3DB504F3; + mul.f32 %r10049, %r7814, 0f3DB504F3; + mul.f32 %r10050, %r7815, 0f3DB504F3; + mul.f32 %r10051, %r7816, 0f3DB504F3; + mul.f32 %r10052, %r7817, 0f3DB504F3; + mul.f32 %r10053, %r7818, 0f3DB504F3; + mul.f32 %r10054, %r7819, 0f3DB504F3; + mul.f32 %r10055, %r7820, 0f3DB504F3; + mul.f32 %r10056, %r7821, 0f3DB504F3; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + rem.s32 %r10057, %r14806, %r2338; + rem.s32 %r10058, %r14805, %r2338; + rem.s32 %r10059, %r14804, %r2338; + rem.s32 %r10060, %r14803, %r2338; + rem.s32 %r10061, %r14802, %r2338; + rem.s32 %r10062, %r14801, %r2338; + rem.s32 %r10063, %r14800, %r2338; + rem.s32 %r10064, %r14799, %r2338; + rem.s32 %r10065, %r14798, %r2338; + rem.s32 %r10066, %r14797, %r2338; + rem.s32 %r10067, %r14796, %r2338; + rem.s32 %r10068, %r14795, %r2338; + rem.s32 %r10069, %r14794, %r2338; + rem.s32 %r10070, %r14793, %r2338; + rem.s32 %r10071, %r14792, %r2338; + rem.s32 %r10072, %r14791, %r2338; + .loc 1 679 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:679:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.u32 %r10073, %r10072, 31; + cvt.u16.u32 %rs129, %r10073; + shr.u32 %r10074, %r10071, 31; + cvt.u16.u32 %rs130, %r10074; + shl.b16 %rs131, %rs130, 1; + or.b16 %rs132, %rs129, %rs131; + shr.u32 %r10075, %r10070, 31; + cvt.u16.u32 %rs133, %r10075; + shl.b16 %rs134, %rs133, 2; + shr.u32 %r10076, %r10069, 31; + cvt.u16.u32 %rs135, %r10076; + shl.b16 %rs136, %rs135, 3; + or.b16 %rs137, %rs136, %rs134; + or.b16 %rs138, %rs132, %rs137; + shr.u32 %r10077, %r10068, 31; + cvt.u16.u32 %rs139, %r10077; + shr.u32 %r10078, %r10067, 31; + cvt.u16.u32 %rs140, %r10078; + shl.b16 %rs141, %rs140, 1; + or.b16 %rs142, %rs139, %rs141; + shr.u32 %r10079, %r10066, 31; + cvt.u16.u32 %rs143, %r10079; + shl.b16 %rs144, %rs143, 2; + shr.u32 %r10080, %r10065, 31; + cvt.u16.u32 %rs145, %r10080; + shl.b16 %rs146, %rs145, 3; + or.b16 %rs147, %rs146, %rs144; + or.b16 %rs148, %rs142, %rs147; + shl.b16 %rs149, %rs148, 4; + or.b16 %rs150, %rs138, %rs149; + shr.u32 %r10081, %r10064, 31; + cvt.u16.u32 %rs151, %r10081; + shr.u32 %r10082, %r10063, 31; + cvt.u16.u32 %rs152, %r10082; + shl.b16 %rs153, %rs152, 1; + or.b16 %rs154, %rs151, %rs153; + shr.u32 %r10083, %r10062, 31; + cvt.u16.u32 %rs155, %r10083; + shl.b16 %rs156, %rs155, 2; + shr.u32 %r10084, %r10061, 31; + cvt.u16.u32 %rs157, %r10084; + shl.b16 %rs158, %rs157, 3; + or.b16 %rs159, %rs158, %rs156; + or.b16 %rs160, %rs154, %rs159; + shl.b16 %rs161, %rs160, 8; + shr.u32 %r10085, %r10060, 31; + cvt.u16.u32 %rs162, %r10085; + shr.u32 %r10086, %r10059, 31; + cvt.u16.u32 %rs163, %r10086; + shl.b16 %rs164, %rs163, 1; + or.b16 %rs165, %rs162, %rs164; + shr.u32 %r10087, %r10058, 31; + cvt.u16.u32 %rs166, %r10087; + shl.b16 %rs167, %rs166, 2; + shr.u32 %r10088, %r10057, 31; + cvt.u16.u32 %rs168, %r10088; + shl.b16 %rs169, %rs168, 3; + or.b16 %rs170, %rs169, %rs167; + or.b16 %rs171, %rs165, %rs170; + shl.b16 %rs172, %rs171, 12; + or.b16 %rs173, %rs172, %rs161; + setp.lt.s32 %p754, %r10068, 0; + setp.lt.s32 %p755, %r10067, 0; + setp.lt.s32 %p756, %r10066, 0; + setp.lt.s32 %p757, %r10065, 0; + setp.lt.s32 %p758, %r10071, 0; + setp.lt.s32 %p759, %r10070, 0; + setp.lt.s32 %p760, %r10069, 0; + setp.lt.s32 %p761, %r10064, 0; + setp.lt.s32 %p762, %r10063, 0; + setp.lt.s32 %p763, %r10062, 0; + setp.lt.s32 %p764, %r10061, 0; + setp.lt.s32 %p765, %r10060, 0; + setp.lt.s32 %p766, %r10059, 0; + setp.lt.s32 %p767, %r10058, 0; + setp.lt.s32 %p768, %r10057, 0; + setp.lt.s32 %p769, %r10072, 0; + .loc 1 681 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:681:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.le.s32 %p770, %r1002, %r10057; + setp.le.s32 %p771, %r1002, %r10058; + setp.le.s32 %p772, %r1003, %r10057; + setp.le.s32 %p773, %r1003, %r10058; + setp.le.s32 %p774, %r1002, %r10059; + setp.le.s32 %p775, %r1002, %r10060; + setp.le.s32 %p776, %r1003, %r10059; + setp.le.s32 %p777, %r1003, %r10060; + setp.le.s32 %p778, %r1002, %r10061; + setp.le.s32 %p779, %r1002, %r10062; + setp.le.s32 %p780, %r1003, %r10061; + setp.le.s32 %p781, %r1003, %r10062; + setp.le.s32 %p782, %r1002, %r10063; + setp.le.s32 %p783, %r1002, %r10064; + setp.le.s32 %p784, %r1003, %r10063; + setp.le.s32 %p785, %r1003, %r10064; + setp.le.s32 %p786, %r1002, %r10065; + setp.le.s32 %p787, %r1002, %r10066; + setp.le.s32 %p788, %r1003, %r10065; + setp.le.s32 %p789, %r1003, %r10066; + setp.le.s32 %p790, %r1002, %r10067; + setp.le.s32 %p791, %r1002, %r10068; + setp.le.s32 %p792, %r1003, %r10067; + setp.le.s32 %p793, %r1003, %r10068; + setp.le.s32 %p794, %r1002, %r10069; + setp.le.s32 %p795, %r1002, %r10070; + setp.le.s32 %p796, %r1003, %r10069; + setp.le.s32 %p797, %r1003, %r10070; + setp.le.s32 %p798, %r1002, %r10071; + setp.le.s32 %p799, %r1002, %r10072; + setp.le.s32 %p800, %r1003, %r10071; + setp.le.s32 %p801, %r1003, %r10072; + .loc 1 682 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:682:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.u16 %rs174, %rs173, 15; + and.b16 %rs175, %rs174, 1; + setp.ne.b16 %p802, %rs175, 0; + and.pred %p803, %p802, %p770; + shr.u16 %rs176, %rs173, 14; + and.b16 %rs177, %rs176, 1; + setp.ne.b16 %p804, %rs177, 0; + and.pred %p805, %p804, %p771; + and.pred %p806, %p802, %p772; + and.pred %p807, %p804, %p773; + shr.u16 %rs178, %rs173, 13; + and.b16 %rs179, %rs178, 1; + setp.ne.b16 %p808, %rs179, 0; + and.pred %p809, %p808, %p774; + shr.u16 %rs180, %rs173, 12; + and.b16 %rs181, %rs180, 1; + setp.ne.b16 %p810, %rs181, 0; + and.pred %p811, %p810, %p775; + and.pred %p812, %p808, %p776; + and.pred %p813, %p810, %p777; + and.b32 %r10089, %r10084, 1; + setp.ne.b32 %p814, %r10089, 0; + and.pred %p815, %p814, %p778; + and.b32 %r10090, %r10083, 1; + setp.ne.b32 %p816, %r10090, 0; + and.pred %p817, %p816, %p779; + and.pred %p818, %p814, %p780; + and.pred %p819, %p816, %p781; + and.b32 %r10091, %r10082, 1; + setp.ne.b32 %p820, %r10091, 0; + and.pred %p821, %p820, %p782; + and.b32 %r10092, %r10081, 1; + setp.ne.b32 %p822, %r10092, 0; + and.pred %p823, %p822, %p783; + and.pred %p824, %p820, %p784; + and.pred %p825, %p822, %p785; + shr.u16 %rs182, %rs150, 7; + and.b16 %rs183, %rs182, 1; + setp.ne.b16 %p826, %rs183, 0; + and.pred %p827, %p826, %p786; + shr.u16 %rs184, %rs150, 6; + and.b16 %rs185, %rs184, 1; + setp.ne.b16 %p828, %rs185, 0; + and.pred %p829, %p828, %p787; + and.pred %p830, %p826, %p788; + and.pred %p831, %p828, %p789; + shr.u16 %rs186, %rs150, 5; + and.b16 %rs187, %rs186, 1; + setp.ne.b16 %p832, %rs187, 0; + and.pred %p833, %p832, %p790; + shr.u16 %rs188, %rs150, 4; + and.b16 %rs189, %rs188, 1; + setp.ne.b16 %p834, %rs189, 0; + and.pred %p835, %p834, %p791; + and.pred %p836, %p832, %p792; + and.pred %p837, %p834, %p793; + and.b32 %r10093, %r10076, 1; + setp.ne.b32 %p838, %r10093, 0; + and.pred %p839, %p838, %p794; + and.b32 %r10094, %r10075, 1; + setp.ne.b32 %p840, %r10094, 0; + and.pred %p841, %p840, %p795; + and.pred %p842, %p838, %p796; + and.pred %p843, %p840, %p797; + and.b32 %r10095, %r10074, 1; + setp.ne.b32 %p844, %r10095, 0; + and.pred %p845, %p844, %p798; + and.pred %p846, %p769, %p799; + and.pred %p847, %p844, %p800; + and.pred %p848, %p769, %p801; + .loc 1 683 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:683:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.gt.s32 %p849, %r10057, -1; + setp.gt.s32 %p850, %r10058, -1; + setp.gt.s32 %p851, %r10059, -1; + setp.gt.s32 %p852, %r10060, -1; + setp.gt.s32 %p853, %r10061, -1; + setp.gt.s32 %p854, %r10062, -1; + setp.gt.s32 %p855, %r10063, -1; + setp.gt.s32 %p856, %r10064, -1; + setp.gt.s32 %p857, %r10065, -1; + setp.gt.s32 %p858, %r10066, -1; + setp.gt.s32 %p859, %r10067, -1; + setp.gt.s32 %p860, %r10068, -1; + setp.gt.s32 %p861, %r10069, -1; + setp.gt.s32 %p862, %r10070, -1; + setp.gt.s32 %p863, %r10071, -1; + setp.gt.s32 %p864, %r10072, -1; + .loc 1 690 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:690:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.b32 %r10096, %r10057, 15; + and.b32 %r10097, %r10058, 15; + and.b32 %r10098, %r10059, 15; + and.b32 %r10099, %r10060, 15; + and.b32 %r10100, %r10061, 15; + and.b32 %r10101, %r10062, 15; + and.b32 %r10102, %r10063, 15; + and.b32 %r10103, %r10064, 15; + and.b32 %r10104, %r10069, 15; + and.b32 %r10105, %r10070, 15; + and.b32 %r10106, %r10071, 15; + and.b32 %r10107, %r10072, 15; + and.b32 %r10108, %r10065, 15; + and.b32 %r10109, %r10066, 15; + and.b32 %r10110, %r10067, 15; + and.b32 %r10111, %r10068, 15; + setp.ne.b32 %p865, %r10111, 0; + setp.ne.b32 %p866, %r10110, 0; + setp.ne.b32 %p867, %r10109, 0; + setp.ne.b32 %p868, %r10108, 0; + setp.ne.b32 %p869, %r10107, 0; + setp.ne.b32 %p870, %r10106, 0; + setp.ne.b32 %p871, %r10105, 0; + setp.ne.b32 %p872, %r10104, 0; + setp.ne.b32 %p873, %r10103, 0; + setp.ne.b32 %p874, %r10102, 0; + setp.ne.b32 %p875, %r10101, 0; + setp.ne.b32 %p876, %r10100, 0; + setp.ne.b32 %p877, %r10099, 0; + setp.ne.b32 %p878, %r10098, 0; + setp.ne.b32 %p879, %r10097, 0; + setp.ne.b32 %p880, %r10096, 0; + .loc 1 690 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:690:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.s32 %r10112, %r10068, 31; + shr.u32 %r10113, %r10112, 28; + add.s32 %r10114, %r10068, %r10113; + shr.s32 %r10115, %r10114, 4; + shr.s32 %r10116, %r10067, 31; + shr.u32 %r10117, %r10116, 28; + add.s32 %r10118, %r10067, %r10117; + shr.s32 %r10119, %r10118, 4; + shr.s32 %r10120, %r10066, 31; + shr.u32 %r10121, %r10120, 28; + add.s32 %r10122, %r10066, %r10121; + shr.s32 %r10123, %r10122, 4; + shr.s32 %r10124, %r10065, 31; + shr.u32 %r10125, %r10124, 28; + add.s32 %r10126, %r10065, %r10125; + shr.s32 %r10127, %r10126, 4; + shr.s32 %r10128, %r10072, 31; + shr.u32 %r10129, %r10128, 28; + add.s32 %r10130, %r10072, %r10129; + shr.s32 %r10131, %r10130, 4; + shr.s32 %r10132, %r10071, 31; + shr.u32 %r10133, %r10132, 28; + add.s32 %r10134, %r10071, %r10133; + shr.s32 %r10135, %r10134, 4; + shr.s32 %r10136, %r10070, 31; + shr.u32 %r10137, %r10136, 28; + add.s32 %r10138, %r10070, %r10137; + shr.s32 %r10139, %r10138, 4; + shr.s32 %r10140, %r10069, 31; + shr.u32 %r10141, %r10140, 28; + add.s32 %r10142, %r10069, %r10141; + shr.s32 %r10143, %r10142, 4; + shr.s32 %r10144, %r10064, 31; + shr.u32 %r10145, %r10144, 28; + add.s32 %r10146, %r10064, %r10145; + shr.s32 %r10147, %r10146, 4; + shr.s32 %r10148, %r10063, 31; + shr.u32 %r10149, %r10148, 28; + add.s32 %r10150, %r10063, %r10149; + shr.s32 %r10151, %r10150, 4; + shr.s32 %r10152, %r10062, 31; + shr.u32 %r10153, %r10152, 28; + add.s32 %r10154, %r10062, %r10153; + shr.s32 %r10155, %r10154, 4; + shr.s32 %r10156, %r10061, 31; + shr.u32 %r10157, %r10156, 28; + add.s32 %r10158, %r10061, %r10157; + shr.s32 %r10159, %r10158, 4; + shr.s32 %r10160, %r10060, 31; + shr.u32 %r10161, %r10160, 28; + add.s32 %r10162, %r10060, %r10161; + shr.s32 %r10163, %r10162, 4; + shr.s32 %r10164, %r10059, 31; + shr.u32 %r10165, %r10164, 28; + add.s32 %r10166, %r10059, %r10165; + shr.s32 %r10167, %r10166, 4; + shr.s32 %r10168, %r10058, 31; + shr.u32 %r10169, %r10168, 28; + add.s32 %r10170, %r10058, %r10169; + shr.s32 %r10171, %r10170, 4; + shr.s32 %r10172, %r10057, 31; + shr.u32 %r10173, %r10172, 28; + add.s32 %r10174, %r10057, %r10173; + shr.s32 %r10175, %r10174, 4; + .loc 1 690 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:690:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.pred %p881, %p768, %p880; + and.pred %p882, %p767, %p879; + and.pred %p883, %p766, %p878; + and.pred %p884, %p765, %p877; + and.pred %p885, %p764, %p876; + and.pred %p886, %p763, %p875; + and.pred %p887, %p762, %p874; + and.pred %p888, %p761, %p873; + and.pred %p889, %p760, %p872; + and.pred %p890, %p759, %p871; + and.pred %p891, %p758, %p870; + and.pred %p892, %p769, %p869; + and.pred %p893, %p757, %p868; + and.pred %p894, %p756, %p867; + and.pred %p895, %p755, %p866; + and.pred %p896, %p754, %p865; + selp.b32 %r10176, -1, 0, %p896; + selp.b32 %r10177, -1, 0, %p895; + selp.b32 %r10178, -1, 0, %p894; + selp.b32 %r10179, -1, 0, %p893; + selp.b32 %r10180, -1, 0, %p892; + selp.b32 %r10181, -1, 0, %p891; + selp.b32 %r10182, -1, 0, %p890; + selp.b32 %r10183, -1, 0, %p889; + selp.b32 %r10184, -1, 0, %p888; + selp.b32 %r10185, -1, 0, %p887; + selp.b32 %r10186, -1, 0, %p886; + selp.b32 %r10187, -1, 0, %p885; + selp.b32 %r10188, -1, 0, %p884; + selp.b32 %r10189, -1, 0, %p883; + selp.b32 %r10190, -1, 0, %p882; + selp.b32 %r10191, -1, 0, %p881; + add.s32 %r10192, %r10175, %r10191; + add.s32 %r10193, %r10171, %r10190; + add.s32 %r10194, %r10167, %r10189; + add.s32 %r10195, %r10163, %r10188; + add.s32 %r10196, %r10159, %r10187; + add.s32 %r10197, %r10155, %r10186; + add.s32 %r10198, %r10151, %r10185; + add.s32 %r10199, %r10147, %r10184; + add.s32 %r10200, %r10143, %r10183; + add.s32 %r10201, %r10139, %r10182; + add.s32 %r10202, %r10135, %r10181; + add.s32 %r10203, %r10131, %r10180; + add.s32 %r10204, %r10127, %r10179; + add.s32 %r10205, %r10123, %r10178; + add.s32 %r10206, %r10119, %r10177; + add.s32 %r10207, %r10115, %r10176; + .loc 1 693 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:693:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.eq.b32 %p897, %r10207, %r741; + selp.b16 %rs190, 1, 0, %p897; + shl.b16 %rs191, %rs190, 2; + setp.eq.b32 %p898, %r10206, %r741; + selp.b16 %rs192, -1, 0, %p898; + shl.b16 %rs193, %rs192, 3; + or.b16 %rs194, %rs193, %rs191; + setp.eq.b32 %p899, %r10207, %r739; + selp.b16 %rs195, 1, 0, %p899; + setp.eq.b32 %p900, %r10206, %r739; + selp.b16 %rs196, -1, 0, %p900; + shl.b16 %rs197, %rs196, 1; + or.b16 %rs198, %rs195, %rs197; + and.b16 %rs199, %rs198, 3; + or.b16 %rs200, %rs199, %rs194; + and.b16 %rs201, %rs200, 15; + shl.b16 %rs202, %rs201, 8; + setp.eq.b32 %p901, %r10205, %r741; + selp.b16 %rs203, 1, 0, %p901; + shl.b16 %rs204, %rs203, 2; + setp.eq.b32 %p902, %r10204, %r741; + selp.b16 %rs205, -1, 0, %p902; + shl.b16 %rs206, %rs205, 3; + or.b16 %rs207, %rs206, %rs204; + setp.eq.b32 %p903, %r10205, %r739; + selp.b16 %rs208, 1, 0, %p903; + setp.eq.b32 %p904, %r10204, %r739; + selp.b16 %rs209, -1, 0, %p904; + shl.b16 %rs210, %rs209, 1; + or.b16 %rs211, %rs208, %rs210; + and.b16 %rs212, %rs211, 3; + or.b16 %rs213, %rs212, %rs207; + shl.b16 %rs214, %rs213, 12; + or.b16 %rs215, %rs214, %rs202; + setp.eq.b32 %p905, %r10203, %r741; + selp.b16 %rs216, 1, 0, %p905; + shl.b16 %rs217, %rs216, 2; + setp.eq.b32 %p906, %r10202, %r741; + selp.b16 %rs218, -1, 0, %p906; + shl.b16 %rs219, %rs218, 3; + or.b16 %rs220, %rs219, %rs217; + setp.eq.b32 %p907, %r10203, %r739; + selp.b16 %rs221, 1, 0, %p907; + setp.eq.b32 %p908, %r10202, %r739; + selp.b16 %rs222, -1, 0, %p908; + shl.b16 %rs223, %rs222, 1; + or.b16 %rs224, %rs221, %rs223; + and.b16 %rs225, %rs224, 3; + or.b16 %rs226, %rs225, %rs220; + and.b16 %rs227, %rs226, 15; + setp.eq.b32 %p909, %r10201, %r741; + selp.b16 %rs228, 1, 0, %p909; + shl.b16 %rs229, %rs228, 2; + setp.eq.b32 %p910, %r10200, %r741; + selp.b16 %rs230, -1, 0, %p910; + shl.b16 %rs231, %rs230, 3; + or.b16 %rs232, %rs231, %rs229; + setp.eq.b32 %p911, %r10201, %r739; + selp.b16 %rs233, 1, 0, %p911; + setp.eq.b32 %p912, %r10200, %r739; + selp.b16 %rs234, -1, 0, %p912; + shl.b16 %rs235, %rs234, 1; + or.b16 %rs236, %rs233, %rs235; + and.b16 %rs237, %rs236, 3; + or.b16 %rs238, %rs237, %rs232; + shl.b16 %rs239, %rs238, 4; + or.b16 %rs240, %rs227, %rs239; + and.b16 %rs241, %rs240, 255; + or.b16 %rs242, %rs241, %rs215; + cvt.u32.u16 %r10208, %rs242; + setp.eq.b32 %p913, %r10199, %r739; + setp.eq.b32 %p914, %r10198, %r739; + setp.eq.b32 %p915, %r10199, %r741; + setp.eq.b32 %p916, %r10198, %r741; + setp.eq.b32 %p917, %r10197, %r739; + setp.eq.b32 %p918, %r10196, %r739; + setp.eq.b32 %p919, %r10197, %r741; + setp.eq.b32 %p920, %r10196, %r741; + setp.eq.b32 %p921, %r10195, %r739; + setp.eq.b32 %p922, %r10194, %r739; + setp.eq.b32 %p923, %r10195, %r741; + setp.eq.b32 %p924, %r10194, %r741; + setp.eq.b32 %p925, %r10193, %r739; + setp.eq.b32 %p926, %r10192, %r739; + setp.eq.b32 %p927, %r10193, %r741; + setp.eq.b32 %p928, %r10192, %r741; + .loc 1 694 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:694:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.pred %p929, %p849, %p928; + and.pred %p930, %p850, %p927; + and.pred %p931, %p849, %p926; + and.pred %p932, %p850, %p925; + and.pred %p933, %p851, %p924; + and.pred %p934, %p852, %p923; + and.pred %p935, %p851, %p922; + and.pred %p936, %p852, %p921; + and.pred %p937, %p853, %p920; + and.pred %p938, %p854, %p919; + and.pred %p939, %p853, %p918; + and.pred %p940, %p854, %p917; + and.pred %p941, %p855, %p916; + and.pred %p942, %p856, %p915; + and.pred %p943, %p855, %p914; + and.pred %p944, %p856, %p913; + shr.u32 %r10209, %r10208, 15; + and.b32 %r10210, %r10209, 1; + setp.ne.b32 %p945, %r10210, 0; + and.pred %p946, %p857, %p945; + shr.u32 %r10211, %r10208, 14; + and.b32 %r10212, %r10211, 1; + setp.ne.b32 %p947, %r10212, 0; + and.pred %p948, %p858, %p947; + shr.u32 %r10213, %r10208, 13; + and.b32 %r10214, %r10213, 1; + setp.ne.b32 %p949, %r10214, 0; + and.pred %p950, %p857, %p949; + shr.u32 %r10215, %r10208, 12; + and.b32 %r10216, %r10215, 1; + setp.ne.b32 %p951, %r10216, 0; + and.pred %p952, %p858, %p951; + shr.u32 %r10217, %r10208, 11; + and.b32 %r10218, %r10217, 1; + setp.ne.b32 %p953, %r10218, 0; + and.pred %p954, %p859, %p953; + shr.u32 %r10219, %r10208, 10; + and.b32 %r10220, %r10219, 1; + setp.ne.b32 %p955, %r10220, 0; + and.pred %p956, %p860, %p955; + shr.u32 %r10221, %r10208, 9; + and.b32 %r10222, %r10221, 1; + setp.ne.b32 %p957, %r10222, 0; + and.pred %p958, %p859, %p957; + shr.u32 %r10223, %r10208, 8; + and.b32 %r10224, %r10223, 1; + setp.ne.b32 %p959, %r10224, 0; + and.pred %p960, %p860, %p959; + shr.u32 %r10225, %r10208, 7; + and.b32 %r10226, %r10225, 1; + setp.ne.b32 %p961, %r10226, 0; + and.pred %p962, %p861, %p961; + shr.u32 %r10227, %r10208, 6; + and.b32 %r10228, %r10227, 1; + setp.ne.b32 %p963, %r10228, 0; + and.pred %p964, %p862, %p963; + shr.u32 %r10229, %r10208, 5; + and.b32 %r10230, %r10229, 1; + setp.ne.b32 %p965, %r10230, 0; + and.pred %p966, %p861, %p965; + shr.u32 %r10231, %r10208, 4; + and.b32 %r10232, %r10231, 1; + setp.ne.b32 %p967, %r10232, 0; + and.pred %p968, %p862, %p967; + shr.u32 %r10233, %r10208, 3; + and.b32 %r10234, %r10233, 1; + setp.ne.b32 %p969, %r10234, 0; + and.pred %p970, %p863, %p969; + shr.u32 %r10235, %r10208, 2; + and.b32 %r10236, %r10235, 1; + setp.ne.b32 %p971, %r10236, 0; + and.pred %p972, %p864, %p971; + shr.u32 %r10237, %r10208, 1; + and.b32 %r10238, %r10237, 1; + setp.ne.b32 %p973, %r10238, 0; + and.pred %p974, %p863, %p973; + and.pred %p975, %p864, %p907; + .loc 1 696 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:696:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + or.pred %p976, %p803, %p929; + or.pred %p977, %p805, %p930; + or.pred %p978, %p806, %p931; + or.pred %p979, %p807, %p932; + or.pred %p980, %p809, %p933; + or.pred %p981, %p811, %p934; + or.pred %p982, %p812, %p935; + or.pred %p983, %p813, %p936; + or.pred %p984, %p815, %p937; + or.pred %p985, %p817, %p938; + or.pred %p986, %p818, %p939; + or.pred %p987, %p819, %p940; + or.pred %p988, %p821, %p941; + or.pred %p989, %p823, %p942; + or.pred %p990, %p824, %p943; + or.pred %p991, %p825, %p944; + or.pred %p992, %p827, %p946; + or.pred %p993, %p829, %p948; + or.pred %p994, %p830, %p950; + or.pred %p995, %p831, %p952; + or.pred %p996, %p833, %p954; + or.pred %p997, %p835, %p956; + or.pred %p998, %p836, %p958; + or.pred %p999, %p837, %p960; + or.pred %p1000, %p839, %p962; + or.pred %p1001, %p841, %p964; + or.pred %p1002, %p842, %p966; + or.pred %p1003, %p843, %p968; + or.pred %p1004, %p845, %p970; + or.pred %p1005, %p846, %p972; + or.pred %p1006, %p847, %p974; + or.pred %p1007, %p848, %p975; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.pred %p1008, %p976, %p722; + and.pred %p1009, %p977, %p723; + and.pred %p1010, %p978, %p722; + and.pred %p1011, %p979, %p723; + and.pred %p1012, %p980, %p724; + and.pred %p1013, %p981, %p725; + and.pred %p1014, %p982, %p724; + and.pred %p1015, %p983, %p725; + and.pred %p1016, %p984, %p726; + and.pred %p1017, %p985, %p727; + and.pred %p1018, %p986, %p726; + and.pred %p1019, %p987, %p727; + and.pred %p1020, %p988, %p728; + and.pred %p1021, %p989, %p729; + and.pred %p1022, %p990, %p728; + and.pred %p1023, %p991, %p729; + and.pred %p1024, %p992, %p730; + and.pred %p1025, %p993, %p731; + and.pred %p1026, %p994, %p730; + and.pred %p1027, %p995, %p731; + and.pred %p1028, %p996, %p732; + and.pred %p1029, %p997, %p733; + and.pred %p1030, %p998, %p732; + and.pred %p1031, %p999, %p733; + and.pred %p1032, %p1000, %p734; + and.pred %p1033, %p1001, %p735; + and.pred %p1034, %p1002, %p734; + and.pred %p1035, %p1003, %p735; + and.pred %p1036, %p1004, %p736; + and.pred %p1037, %p1005, %p737; + and.pred %p1038, %p1006, %p736; + and.pred %p1039, %p1007, %p737; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10239, %r10025, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10240, %r10239, 0fFF800000, %p1008; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10241, %r10026, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10242, %r10241, 0fFF800000, %p1009; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10243, %r10027, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10244, %r10243, 0fFF800000, %p1010; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10245, %r10028, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10246, %r10245, 0fFF800000, %p1011; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10247, %r10029, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10248, %r10247, 0fFF800000, %p1012; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10249, %r10030, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10250, %r10249, 0fFF800000, %p1013; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10251, %r10031, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10252, %r10251, 0fFF800000, %p1014; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10253, %r10032, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10254, %r10253, 0fFF800000, %p1015; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10255, %r10033, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10256, %r10255, 0fFF800000, %p1016; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10257, %r10034, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10258, %r10257, 0fFF800000, %p1017; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10259, %r10035, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10260, %r10259, 0fFF800000, %p1018; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10261, %r10036, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10262, %r10261, 0fFF800000, %p1019; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10263, %r10037, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10264, %r10263, 0fFF800000, %p1020; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10265, %r10038, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10266, %r10265, 0fFF800000, %p1021; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10267, %r10039, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10268, %r10267, 0fFF800000, %p1022; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10269, %r10040, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10270, %r10269, 0fFF800000, %p1023; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10271, %r10041, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10272, %r10271, 0fFF800000, %p1024; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10273, %r10042, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10274, %r10273, 0fFF800000, %p1025; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10275, %r10043, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10276, %r10275, 0fFF800000, %p1026; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10277, %r10044, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10278, %r10277, 0fFF800000, %p1027; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10279, %r10045, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10280, %r10279, 0fFF800000, %p1028; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10281, %r10046, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10282, %r10281, 0fFF800000, %p1029; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10283, %r10047, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10284, %r10283, 0fFF800000, %p1030; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10285, %r10048, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10286, %r10285, 0fFF800000, %p1031; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10287, %r10049, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10288, %r10287, 0fFF800000, %p1032; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10289, %r10050, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10290, %r10289, 0fFF800000, %p1033; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10291, %r10051, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10292, %r10291, 0fFF800000, %p1034; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10293, %r10052, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10294, %r10293, 0fFF800000, %p1035; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10295, %r10053, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10296, %r10295, 0fFF800000, %p1036; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10297, %r10054, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10298, %r10297, 0fFF800000, %p1037; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10299, %r10055, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10300, %r10299, 0fFF800000, %p1038; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10301, %r10056, 0f3FB8AA3B; + .loc 1 700 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:700:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.f32 %r10302, %r10301, 0fFF800000, %p1039; + .loc 1 704 40 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:704:40 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + sub.f32 %r10303, %r10240, %r9968; + sub.f32 %r10304, %r10242, %r9969; + sub.f32 %r10305, %r10244, %r9968; + sub.f32 %r10306, %r10246, %r9969; + sub.f32 %r10307, %r10248, %r9970; + sub.f32 %r10308, %r10250, %r9971; + sub.f32 %r10309, %r10252, %r9970; + sub.f32 %r10310, %r10254, %r9971; + sub.f32 %r10311, %r10256, %r9972; + sub.f32 %r10312, %r10258, %r9973; + sub.f32 %r10313, %r10260, %r9972; + sub.f32 %r10314, %r10262, %r9973; + sub.f32 %r10315, %r10264, %r9974; + sub.f32 %r10316, %r10266, %r9975; + sub.f32 %r10317, %r10268, %r9974; + sub.f32 %r10318, %r10270, %r9975; + sub.f32 %r10319, %r10272, %r9976; + sub.f32 %r10320, %r10274, %r9977; + sub.f32 %r10321, %r10276, %r9976; + sub.f32 %r10322, %r10278, %r9977; + sub.f32 %r10323, %r10280, %r9978; + sub.f32 %r10324, %r10282, %r9979; + sub.f32 %r10325, %r10284, %r9978; + sub.f32 %r10326, %r10286, %r9979; + sub.f32 %r10327, %r10288, %r9980; + sub.f32 %r10328, %r10290, %r9981; + sub.f32 %r10329, %r10292, %r9980; + sub.f32 %r10330, %r10294, %r9981; + sub.f32 %r10331, %r10296, %r9982; + sub.f32 %r10332, %r10298, %r9983; + sub.f32 %r10333, %r10300, %r9982; + sub.f32 %r10334, %r10302, %r9983; + .loc 1 704 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:704:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + ex2.approx.ftz.f32 %r10335, %r10303; + ex2.approx.ftz.f32 %r10336, %r10304; + ex2.approx.ftz.f32 %r10337, %r10305; + ex2.approx.ftz.f32 %r10338, %r10306; + ex2.approx.ftz.f32 %r10339, %r10307; + ex2.approx.ftz.f32 %r10340, %r10308; + ex2.approx.ftz.f32 %r10341, %r10309; + ex2.approx.ftz.f32 %r10342, %r10310; + ex2.approx.ftz.f32 %r10343, %r10311; + ex2.approx.ftz.f32 %r10344, %r10312; + ex2.approx.ftz.f32 %r10345, %r10313; + ex2.approx.ftz.f32 %r10346, %r10314; + ex2.approx.ftz.f32 %r10347, %r10315; + ex2.approx.ftz.f32 %r10348, %r10316; + ex2.approx.ftz.f32 %r10349, %r10317; + ex2.approx.ftz.f32 %r10350, %r10318; + ex2.approx.ftz.f32 %r10351, %r10319; + ex2.approx.ftz.f32 %r10352, %r10320; + ex2.approx.ftz.f32 %r10353, %r10321; + ex2.approx.ftz.f32 %r10354, %r10322; + ex2.approx.ftz.f32 %r10355, %r10323; + ex2.approx.ftz.f32 %r10356, %r10324; + ex2.approx.ftz.f32 %r10357, %r10325; + ex2.approx.ftz.f32 %r10358, %r10326; + ex2.approx.ftz.f32 %r10359, %r10327; + ex2.approx.ftz.f32 %r10360, %r10328; + ex2.approx.ftz.f32 %r10361, %r10329; + ex2.approx.ftz.f32 %r10362, %r10330; + ex2.approx.ftz.f32 %r10363, %r10331; + ex2.approx.ftz.f32 %r10364, %r10332; + ex2.approx.ftz.f32 %r10365, %r10333; + ex2.approx.ftz.f32 %r10366, %r10334; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10367, %r7390, 49152; + add.s32 %r9293, %r10367, %r9946; + .loc 1 708 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:708:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16x2.f32 %r8378, %r10336, %r10335; + cvt.rn.bf16x2.f32 %r8379, %r10338, %r10337; + cvt.rn.bf16x2.f32 %r8380, %r10340, %r10339; + cvt.rn.bf16x2.f32 %r8381, %r10342, %r10341; + cvt.rn.bf16x2.f32 %r8510, %r10344, %r10343; + cvt.rn.bf16x2.f32 %r8511, %r10346, %r10345; + cvt.rn.bf16x2.f32 %r8512, %r10348, %r10347; + cvt.rn.bf16x2.f32 %r8513, %r10350, %r10349; + cvt.rn.bf16x2.f32 %r8642, %r10352, %r10351; + cvt.rn.bf16x2.f32 %r8643, %r10354, %r10353; + cvt.rn.bf16x2.f32 %r8644, %r10356, %r10355; + cvt.rn.bf16x2.f32 %r8645, %r10358, %r10357; + cvt.rn.bf16x2.f32 %r8774, %r10360, %r10359; + cvt.rn.bf16x2.f32 %r8775, %r10362, %r10361; + cvt.rn.bf16x2.f32 %r8776, %r10364, %r10363; + cvt.rn.bf16x2.f32 %r8777, %r10366, %r10365; + .loc 1 708 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:708:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + wgmma.fence.sync.aligned; + bfe.u32 %r10368, %r9293, 4, 14; + cvt.u64.u32 %rd770, %r10368; + or.b64 %rd684, %rd770, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r8378,%r8379,%r8380,%r8381}, %rd684, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10369, %r9293, 2048; + bfe.u32 %r10370, %r10369, 4, 14; + cvt.u64.u32 %rd771, %r10370; + or.b64 %rd685, %rd771, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r8510,%r8511,%r8512,%r8513}, %rd685, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10371, %r9293, 4096; + bfe.u32 %r10372, %r10371, 4, 14; + cvt.u64.u32 %rd772, %r10372; + or.b64 %rd686, %rd772, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r8642,%r8643,%r8644,%r8645}, %rd686, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10373, %r9293, 6144; + bfe.u32 %r10374, %r10373, 4, 14; + cvt.u64.u32 %rd773, %r10374; + or.b64 %rd687, %rd773, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r8774,%r8775,%r8776,%r8777}, %rd687, %p663, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10375, %r7390, 98816; + add.s32 %r10376, %r10375, %r9948; + add.s32 %r10377, %r10376, %r787; + ld.shared.v2.b32 {%r10378, %r10379}, [%r10377]; + ld.shared.v2.b32 {%r10380, %r10381}, [%r10377+32]; + ld.shared.v2.b32 {%r10382, %r10383}, [%r10377+64]; + ld.shared.v2.b32 {%r10384, %r10385}, [%r10377+96]; + ld.shared.v2.b32 {%r10386, %r10387}, [%r10377+128]; + ld.shared.v2.b32 {%r10388, %r10389}, [%r10377+160]; + ld.shared.v2.b32 {%r10390, %r10391}, [%r10377+192]; + ld.shared.v2.b32 {%r10392, %r10393}, [%r10377+224]; + .loc 1 714 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:714:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + wgmma.fence.sync.aligned; + add.s32 %r9290, %r7390, 132096; + add.s32 %r10394, %r9986, %r9290; + bfe.u32 %r10395, %r10394, 4, 14; + cvt.u64.u32 %rd774, %r10395; + or.b64 %rd688, %rd774, 4611686293372403712; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd688, %rd684, 0, 1, 1, 0, 0; + // end inline asm + add.s32 %r10396, %r9990, %r9290; + bfe.u32 %r10397, %r10396, 4, 14; + cvt.u64.u32 %rd775, %r10397; + or.b64 %rd690, %rd775, 4611686293372403712; + add.s32 %r10398, %r9293, 32; + bfe.u32 %r10399, %r10398, 4, 14; + cvt.u64.u32 %rd776, %r10399; + or.b64 %rd691, %rd776, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd690, %rd691, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10400, %r9995, %r9290; + bfe.u32 %r10401, %r10400, 4, 14; + cvt.u64.u32 %rd777, %r10401; + or.b64 %rd692, %rd777, 4611686293372403712; + add.s32 %r10402, %r9293, 64; + bfe.u32 %r10403, %r10402, 4, 14; + cvt.u64.u32 %rd778, %r10403; + or.b64 %rd693, %rd778, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd692, %rd693, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10404, %r10000, %r9290; + bfe.u32 %r10405, %r10404, 4, 14; + cvt.u64.u32 %rd779, %r10405; + or.b64 %rd694, %rd779, 4611686293372403712; + add.s32 %r10406, %r9293, 96; + bfe.u32 %r10407, %r10406, 4, 14; + cvt.u64.u32 %rd780, %r10407; + or.b64 %rd695, %rd780, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd694, %rd695, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10408, %r10005, %r9290; + bfe.u32 %r10409, %r10408, 4, 14; + cvt.u64.u32 %rd781, %r10409; + or.b64 %rd696, %rd781, 4611686293372403712; + add.s32 %r10410, %r9293, 8192; + bfe.u32 %r10411, %r10410, 4, 14; + cvt.u64.u32 %rd782, %r10411; + or.b64 %rd697, %rd782, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd696, %rd697, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10412, %r10010, %r9290; + bfe.u32 %r10413, %r10412, 4, 14; + cvt.u64.u32 %rd783, %r10413; + or.b64 %rd698, %rd783, 4611686293372403712; + add.s32 %r10414, %r9293, 8224; + bfe.u32 %r10415, %r10414, 4, 14; + cvt.u64.u32 %rd784, %r10415; + or.b64 %rd699, %rd784, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd698, %rd699, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10416, %r10015, %r9290; + bfe.u32 %r10417, %r10416, 4, 14; + cvt.u64.u32 %rd785, %r10417; + or.b64 %rd700, %rd785, 4611686293372403712; + add.s32 %r10418, %r9293, 8256; + bfe.u32 %r10419, %r10418, 4, 14; + cvt.u64.u32 %rd786, %r10419; + or.b64 %rd701, %rd786, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd700, %rd701, %p663, 1, 1, 0, 0; + // end inline asm + add.s32 %r10420, %r10020, %r9290; + bfe.u32 %r10421, %r10420, 4, 14; + cvt.u64.u32 %rd787, %r10421; + or.b64 %rd702, %rd787, 4611686293372403712; + add.s32 %r10422, %r9293, 8288; + bfe.u32 %r10423, %r10422, 4, 14; + cvt.u64.u32 %rd788, %r10423; + or.b64 %rd703, %rd788, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905}, %rd702, %rd703, %p663, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r9295, %r8245; + mov.b32 %r9291, %r8245; + mov.b32 %r9292, %r8245; + mov.b32 %r9294, %r8245; + // begin inline asm + // wait for regs: %r8874,%r8875,%r8876,%r8877,%r8878,%r8879,%r8880,%r8881,%r8882,%r8883,%r8884,%r8885,%r8886,%r8887,%r8888,%r8889,%r8890,%r8891,%r8892,%r8893,%r8894,%r8895,%r8896,%r8897,%r8898,%r8899,%r8900,%r8901,%r8902,%r8903,%r8904,%r8905,%r9290,%r9291,%r9292,%r9293,%r9294,%r9295 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 715 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:715:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + sub.f32 %r10424, %r8874, %r10378; + sub.f32 %r10425, %r8875, %r10379; + sub.f32 %r10426, %r8876, %r10378; + sub.f32 %r10427, %r8877, %r10379; + sub.f32 %r10428, %r8878, %r10380; + sub.f32 %r10429, %r8879, %r10381; + sub.f32 %r10430, %r8880, %r10380; + sub.f32 %r10431, %r8881, %r10381; + sub.f32 %r10432, %r8882, %r10382; + sub.f32 %r10433, %r8883, %r10383; + sub.f32 %r10434, %r8884, %r10382; + sub.f32 %r10435, %r8885, %r10383; + sub.f32 %r10436, %r8886, %r10384; + sub.f32 %r10437, %r8887, %r10385; + sub.f32 %r10438, %r8888, %r10384; + sub.f32 %r10439, %r8889, %r10385; + sub.f32 %r10440, %r8890, %r10386; + sub.f32 %r10441, %r8891, %r10387; + sub.f32 %r10442, %r8892, %r10386; + sub.f32 %r10443, %r8893, %r10387; + sub.f32 %r10444, %r8894, %r10388; + sub.f32 %r10445, %r8895, %r10389; + sub.f32 %r10446, %r8896, %r10388; + sub.f32 %r10447, %r8897, %r10389; + sub.f32 %r10448, %r8898, %r10390; + sub.f32 %r10449, %r8899, %r10391; + sub.f32 %r10450, %r8900, %r10390; + sub.f32 %r10451, %r8901, %r10391; + sub.f32 %r10452, %r8902, %r10392; + sub.f32 %r10453, %r8903, %r10393; + sub.f32 %r10454, %r8904, %r10392; + sub.f32 %r10455, %r8905, %r10393; + .loc 1 715 16 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:715:16 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.f32 %r10456, %r10335, %r10424; + mul.f32 %r10457, %r10336, %r10425; + mul.f32 %r10458, %r10337, %r10426; + mul.f32 %r10459, %r10338, %r10427; + mul.f32 %r10460, %r10339, %r10428; + mul.f32 %r10461, %r10340, %r10429; + mul.f32 %r10462, %r10341, %r10430; + mul.f32 %r10463, %r10342, %r10431; + mul.f32 %r10464, %r10343, %r10432; + mul.f32 %r10465, %r10344, %r10433; + mul.f32 %r10466, %r10345, %r10434; + mul.f32 %r10467, %r10346, %r10435; + mul.f32 %r10468, %r10347, %r10436; + mul.f32 %r10469, %r10348, %r10437; + mul.f32 %r10470, %r10349, %r10438; + mul.f32 %r10471, %r10350, %r10439; + mul.f32 %r10472, %r10351, %r10440; + mul.f32 %r10473, %r10352, %r10441; + mul.f32 %r10474, %r10353, %r10442; + mul.f32 %r10475, %r10354, %r10443; + mul.f32 %r10476, %r10355, %r10444; + mul.f32 %r10477, %r10356, %r10445; + mul.f32 %r10478, %r10357, %r10446; + mul.f32 %r10479, %r10358, %r10447; + mul.f32 %r10480, %r10359, %r10448; + mul.f32 %r10481, %r10360, %r10449; + mul.f32 %r10482, %r10361, %r10450; + mul.f32 %r10483, %r10362, %r10451; + mul.f32 %r10484, %r10363, %r10452; + mul.f32 %r10485, %r10364, %r10453; + mul.f32 %r10486, %r10365, %r10454; + mul.f32 %r10487, %r10366, %r10455; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs243, %r10456; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs244, %rs243, 0x0000, %p1008; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs245, %r10457; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs246, %rs245, 0x0000, %p1009; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs247, %r10458; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs248, %rs247, 0x0000, %p1010; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs249, %r10459; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs250, %rs249, 0x0000, %p1011; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs251, %r10460; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs252, %rs251, 0x0000, %p1012; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs253, %r10461; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs254, %rs253, 0x0000, %p1013; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs255, %r10462; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs256, %rs255, 0x0000, %p1014; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs257, %r10463; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs258, %rs257, 0x0000, %p1015; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs259, %r10464; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs260, %rs259, 0x0000, %p1016; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs261, %r10465; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs262, %rs261, 0x0000, %p1017; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs263, %r10466; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs264, %rs263, 0x0000, %p1018; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs265, %r10467; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs266, %rs265, 0x0000, %p1019; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs267, %r10468; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs268, %rs267, 0x0000, %p1020; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs269, %r10469; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs270, %rs269, 0x0000, %p1021; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs271, %r10470; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs272, %rs271, 0x0000, %p1022; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs273, %r10471; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs274, %rs273, 0x0000, %p1023; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs275, %r10472; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs276, %rs275, 0x0000, %p1024; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs277, %r10473; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs278, %rs277, 0x0000, %p1025; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs279, %r10474; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs280, %rs279, 0x0000, %p1026; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs281, %r10475; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs282, %rs281, 0x0000, %p1027; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs283, %r10476; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs284, %rs283, 0x0000, %p1028; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs285, %r10477; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs286, %rs285, 0x0000, %p1029; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs287, %r10478; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs288, %rs287, 0x0000, %p1030; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs289, %r10479; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs290, %rs289, 0x0000, %p1031; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs291, %r10480; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs292, %rs291, 0x0000, %p1032; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs293, %r10481; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs294, %rs293, 0x0000, %p1033; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs295, %r10482; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs296, %rs295, 0x0000, %p1034; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs297, %r10483; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs298, %rs297, 0x0000, %p1035; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs299, %r10484; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs300, %rs299, 0x0000, %p1036; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs301, %r10485; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs302, %rs301, 0x0000, %p1037; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs303, %r10486; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs304, %rs303, 0x0000, %p1038; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + cvt.rn.bf16.f32 %rs305, %r10487; + .loc 1 737 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:737:45 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + selp.b16 %rs306, %rs305, 0x0000, %p1039; + .loc 1 739 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mov.b32 %r9462, {%rs244, %rs246}; + mov.b32 %r9463, {%rs248, %rs250}; + mov.b32 %r9464, {%rs252, %rs254}; + mov.b32 %r9465, {%rs256, %rs258}; + mov.b32 %r9594, {%rs260, %rs262}; + mov.b32 %r9595, {%rs264, %rs266}; + mov.b32 %r9596, {%rs268, %rs270}; + mov.b32 %r9597, {%rs272, %rs274}; + mov.b32 %r9726, {%rs276, %rs278}; + mov.b32 %r9727, {%rs280, %rs282}; + mov.b32 %r9728, {%rs284, %rs286}; + mov.b32 %r9729, {%rs288, %rs290}; + mov.b32 %r9858, {%rs292, %rs294}; + mov.b32 %r9859, {%rs296, %rs298}; + mov.b32 %r9860, {%rs300, %rs302}; + mov.b32 %r9861, {%rs304, %rs306}; + wgmma.fence.sync.aligned; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r9462,%r9463,%r9464,%r9465}, %rd669, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10488, %r8247, 2048; + bfe.u32 %r10489, %r10488, 4, 14; + cvt.u64.u32 %rd789, %r10489; + or.b64 %rd705, %rd789, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r9594,%r9595,%r9596,%r9597}, %rd705, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10490, %r8247, 4096; + bfe.u32 %r10491, %r10490, 4, 14; + cvt.u64.u32 %rd790, %r10491; + or.b64 %rd706, %rd790, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r9726,%r9727,%r9728,%r9729}, %rd706, %p663, 1, 1, 1; + // end inline asm + add.s32 %r10492, %r8247, 6144; + bfe.u32 %r10493, %r10492, 4, 14; + cvt.u64.u32 %rd791, %r10493; + or.b64 %rd707, %rd791, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r9858,%r9859,%r9860,%r9861}, %rd707, %p663, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 610 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:610:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r14791, %r14791, %r14617; + add.s32 %r14792, %r14792, %r14617; + add.s32 %r14793, %r14793, %r14617; + add.s32 %r14794, %r14794, %r14617; + add.s32 %r14795, %r14795, %r14617; + add.s32 %r14796, %r14796, %r14617; + add.s32 %r14797, %r14797, %r14617; + add.s32 %r14798, %r14798, %r14617; + add.s32 %r14799, %r14799, %r14617; + add.s32 %r14800, %r14800, %r14617; + add.s32 %r14801, %r14801, %r14617; + add.s32 %r14802, %r14802, %r14617; + add.s32 %r14803, %r14803, %r14617; + add.s32 %r14804, %r14804, %r14617; + add.s32 %r14805, %r14805, %r14617; + add.s32 %r14806, %r14806, %r14617; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r1468, %r14790, 1; + .loc 1 752 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:752:33 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shr.u32 %r10494, %r1468, 1; + .loc 1 753 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mad.wide.u32 %rd709, %r10494, 4, %rd522; + .loc 1 753 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + mov.u64 %rd708, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd708, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r9862, 0x0; + @%p685 ld.global.L1::evict_last.L2::cache_hint.b32 { %r9862 }, [ %rd709 + 0 ], %rd708; + // end inline asm + .loc 1 754 109 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:109 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10495, %r10494, 1; + .loc 1 754 113 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:113 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p1040, %r10495, %r7366; + .loc 1 754 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:55 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd712, %rd709, 4; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.pred %p686, %p685, %p1040; + .loc 1 754 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + mov.u64 %rd711, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd711, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r9863, 0x0; + @%p686 ld.global.L1::evict_last.L2::cache_hint.b32 { %r9863 }, [ %rd712 + 0 ], %rd711; + // end inline asm + .loc 1 755 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:755:35 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + and.b32 %r10496, %r14790, 1; + .loc 1 756 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:34 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + sub.s32 %r10497, %r9863, %r9862; + .loc 1 756 48 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:48 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10498, %r10497, 7; + .loc 1 756 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10499, %r10498, -64; + .loc 1 757 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + xor.b32 %r10500, %r10496, 1; + .loc 1 757 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10501, %r10496, 6; + .loc 1 757 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mad.lo.s32 %r14617, %r10499, %r10500, %r10501; + .loc 1 608 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10502, %r14617, 12; + .loc 1 608 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.wide.s32 %rd792, %r10502, 2; + add.s64 %rd1128, %rd1128, %rd792; + add.s64 %rd1127, %rd1127, %rd792; + add.s64 %rd1126, %rd1126, %rd792; + add.s64 %rd1125, %rd1125, %rd792; + .loc 1 609 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10503, %r14617, 7; + .loc 1 609 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.wide.s32 %rd793, %r10503, 2; + add.s64 %rd1124, %rd1124, %rd793; + add.s64 %rd1123, %rd1123, %rd793; + add.s64 %rd1122, %rd1122, %rd793; + add.s64 %rd1121, %rd1121, %rd793; + .loc 1 610 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:610:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r1470, %r14617, %r14653; + add.s32 %r1471, %r14617, %r14652; + add.s32 %r1472, %r14617, %r14651; + add.s32 %r1473, %r14617, %r14650; + add.s32 %r1474, %r14617, %r14649; + add.s32 %r1475, %r14617, %r14648; + add.s32 %r1476, %r14617, %r14647; + add.s32 %r1477, %r14617, %r14646; + add.s32 %r1478, %r14617, %r14645; + add.s32 %r1479, %r14617, %r14644; + add.s32 %r1480, %r14617, %r14643; + add.s32 %r1481, %r14617, %r14642; + add.s32 %r1482, %r14617, %r14641; + add.s32 %r1483, %r14617, %r14640; + add.s32 %r1484, %r14617, %r14639; + add.s32 %r1485, %r14617, %r14638; + add.s32 %r14658, %r14617, %r14658; + add.s32 %r14659, %r14617, %r14659; + add.s32 %r14660, %r14617, %r14660; + add.s32 %r14661, %r14617, %r14661; + add.s32 %r14654, %r14617, %r14654; + add.s32 %r14655, %r14617, %r14655; + add.s32 %r14656, %r14617, %r14656; + add.s32 %r14657, %r14617, %r14657; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10504, %r14635, 1; + setp.gt.s32 %p1041, %r10504, 1; + selp.b32 %r14635, 0, %r10504, %p1041; + add.s32 %r10505, %r14637, 1; + setp.gt.s32 %p1042, %r10505, 2; + selp.b32 %r14637, 0, %r10505, %p1042; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p1043, %r14658, %r2338; + setp.lt.s32 %p1044, %r14659, %r2338; + setp.lt.s32 %p1045, %r14660, %r2338; + setp.lt.s32 %p1046, %r14661, %r2338; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10506, %r14637, 14; + add.s32 %r10507, %r7390, %r10506; + bar.sync 0; + add.s32 %r9864, %r10507, %r761; + selp.b32 %r10508, 16, 0, %p1043; + selp.b32 %r9865, %r10508, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9864 + 0 ], [ %rd1128 + 0 ], 0x10, %r9865; + // end inline asm + add.s32 %r9866, %r9864, 2048; + selp.b32 %r10509, 16, 0, %p1044; + selp.b32 %r9867, %r10509, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9866 + 0 ], [ %rd1127 + 0 ], 0x10, %r9867; + // end inline asm + add.s32 %r9868, %r9864, 4096; + selp.b32 %r10510, 16, 0, %p1045; + selp.b32 %r9869, %r10510, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9868 + 0 ], [ %rd1126 + 0 ], 0x10, %r9869; + // end inline asm + add.s32 %r9870, %r9864, 6144; + selp.b32 %r10511, 16, 0, %p1046; + selp.b32 %r9871, %r10511, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9870 + 0 ], [ %rd1125 + 0 ], 0x10, %r9871; + // end inline asm + cp.async.commit_group; + .loc 1 656 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p1047, %r1470, %r2338; + setp.lt.s32 %p1048, %r1471, %r2338; + setp.lt.s32 %p1049, %r1472, %r2338; + setp.lt.s32 %p1050, %r1473, %r2338; + setp.lt.s32 %p1051, %r1474, %r2338; + setp.lt.s32 %p1052, %r1475, %r2338; + setp.lt.s32 %p1053, %r1476, %r2338; + setp.lt.s32 %p1054, %r1477, %r2338; + setp.lt.s32 %p1055, %r1478, %r2338; + setp.lt.s32 %p1056, %r1479, %r2338; + setp.lt.s32 %p1057, %r1480, %r2338; + setp.lt.s32 %p1058, %r1481, %r2338; + setp.lt.s32 %p1059, %r1482, %r2338; + setp.lt.s32 %p1060, %r1483, %r2338; + setp.lt.s32 %p1061, %r1484, %r2338; + setp.lt.s32 %p1062, %r1485, %r2338; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + mul.wide.s32 %rd794, %r1470, 4; + add.s64 %rd718, %rd127, %rd794; + mul.wide.s32 %rd795, %r1471, 4; + add.s64 %rd719, %rd127, %rd795; + mul.wide.s32 %rd796, %r1472, 4; + add.s64 %rd720, %rd127, %rd796; + mul.wide.s32 %rd797, %r1473, 4; + add.s64 %rd721, %rd127, %rd797; + mul.wide.s32 %rd798, %r1474, 4; + add.s64 %rd722, %rd127, %rd798; + mul.wide.s32 %rd799, %r1475, 4; + add.s64 %rd723, %rd127, %rd799; + mul.wide.s32 %rd800, %r1476, 4; + add.s64 %rd724, %rd127, %rd800; + mul.wide.s32 %rd801, %r1477, 4; + add.s64 %rd725, %rd127, %rd801; + mul.wide.s32 %rd802, %r1478, 4; + add.s64 %rd726, %rd127, %rd802; + mul.wide.s32 %rd803, %r1479, 4; + add.s64 %rd727, %rd127, %rd803; + mul.wide.s32 %rd804, %r1480, 4; + add.s64 %rd728, %rd127, %rd804; + mul.wide.s32 %rd805, %r1481, 4; + add.s64 %rd729, %rd127, %rd805; + mul.wide.s32 %rd806, %r1482, 4; + add.s64 %rd730, %rd127, %rd806; + mul.wide.s32 %rd807, %r1483, 4; + add.s64 %rd731, %rd127, %rd807; + mul.wide.s32 %rd808, %r1484, 4; + add.s64 %rd732, %rd127, %rd808; + mul.wide.s32 %rd809, %r1485, 4; + add.s64 %rd733, %rd127, %rd809; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + shl.b32 %r10512, %r14635, 8; + add.s32 %r10513, %r9949, %r10512; + add.s32 %r9872, %r10513, %r787; + selp.b32 %r10514, 4, 0, %p1047; + selp.b32 %r9913, %r10514, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9872 + 0 ], [ %rd718 + 0 ], 0x4, %r9913; + // end inline asm + add.s32 %r9874, %r9872, 4; + selp.b32 %r10515, 4, 0, %p1048; + selp.b32 %r9915, %r10515, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9874 + 0 ], [ %rd719 + 0 ], 0x4, %r9915; + // end inline asm + add.s32 %r9876, %r9872, 32; + selp.b32 %r10516, 4, 0, %p1049; + selp.b32 %r9917, %r10516, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9876 + 0 ], [ %rd720 + 0 ], 0x4, %r9917; + // end inline asm + add.s32 %r9878, %r9872, 36; + selp.b32 %r10517, 4, 0, %p1050; + selp.b32 %r9919, %r10517, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9878 + 0 ], [ %rd721 + 0 ], 0x4, %r9919; + // end inline asm + add.s32 %r9880, %r9872, 64; + selp.b32 %r10518, 4, 0, %p1051; + selp.b32 %r9921, %r10518, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9880 + 0 ], [ %rd722 + 0 ], 0x4, %r9921; + // end inline asm + add.s32 %r9882, %r9872, 68; + selp.b32 %r10519, 4, 0, %p1052; + selp.b32 %r9923, %r10519, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9882 + 0 ], [ %rd723 + 0 ], 0x4, %r9923; + // end inline asm + add.s32 %r9884, %r9872, 96; + selp.b32 %r10520, 4, 0, %p1053; + selp.b32 %r9925, %r10520, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9884 + 0 ], [ %rd724 + 0 ], 0x4, %r9925; + // end inline asm + add.s32 %r9886, %r9872, 100; + selp.b32 %r10521, 4, 0, %p1054; + selp.b32 %r9927, %r10521, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9886 + 0 ], [ %rd725 + 0 ], 0x4, %r9927; + // end inline asm + add.s32 %r9888, %r9872, 128; + selp.b32 %r10522, 4, 0, %p1055; + selp.b32 %r9929, %r10522, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9888 + 0 ], [ %rd726 + 0 ], 0x4, %r9929; + // end inline asm + add.s32 %r9890, %r9872, 132; + selp.b32 %r10523, 4, 0, %p1056; + selp.b32 %r9931, %r10523, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9890 + 0 ], [ %rd727 + 0 ], 0x4, %r9931; + // end inline asm + add.s32 %r9892, %r9872, 160; + selp.b32 %r10524, 4, 0, %p1057; + selp.b32 %r9933, %r10524, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9892 + 0 ], [ %rd728 + 0 ], 0x4, %r9933; + // end inline asm + add.s32 %r9894, %r9872, 164; + selp.b32 %r10525, 4, 0, %p1058; + selp.b32 %r9935, %r10525, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9894 + 0 ], [ %rd729 + 0 ], 0x4, %r9935; + // end inline asm + add.s32 %r9896, %r9872, 192; + selp.b32 %r10526, 4, 0, %p1059; + selp.b32 %r9937, %r10526, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9896 + 0 ], [ %rd730 + 0 ], 0x4, %r9937; + // end inline asm + add.s32 %r9898, %r9872, 196; + selp.b32 %r10527, 4, 0, %p1060; + selp.b32 %r9939, %r10527, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9898 + 0 ], [ %rd731 + 0 ], 0x4, %r9939; + // end inline asm + add.s32 %r9900, %r9872, 224; + selp.b32 %r10528, 4, 0, %p1061; + selp.b32 %r9941, %r10528, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9900 + 0 ], [ %rd732 + 0 ], 0x4, %r9941; + // end inline asm + add.s32 %r9902, %r9872, 228; + selp.b32 %r10529, 4, 0, %p1062; + selp.b32 %r9943, %r10529, 0, %p719; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9902 + 0 ], [ %rd733 + 0 ], 0x4, %r9943; + // end inline asm + cp.async.commit_group; + .loc 1 797 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.lt.s32 %p1063, %r14654, %r2338; + setp.lt.s32 %p1064, %r14655, %r2338; + setp.lt.s32 %p1065, %r14656, %r2338; + setp.lt.s32 %p1066, %r14657, %r2338; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10530, %r10367, %r10506; + add.s32 %r9904, %r10530, %r761; + selp.b32 %r10531, 16, 0, %p1063; + selp.b32 %r9905, %r10531, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9904 + 0 ], [ %rd1124 + 0 ], 0x10, %r9905; + // end inline asm + add.s32 %r9906, %r9904, 2048; + selp.b32 %r10532, 16, 0, %p1064; + selp.b32 %r9907, %r10532, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9906 + 0 ], [ %rd1123 + 0 ], 0x10, %r9907; + // end inline asm + add.s32 %r9908, %r9904, 4096; + selp.b32 %r10533, 16, 0, %p1065; + selp.b32 %r9909, %r10533, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9908 + 0 ], [ %rd1122 + 0 ], 0x10, %r9909; + // end inline asm + add.s32 %r9910, %r9904, 6144; + selp.b32 %r10534, 16, 0, %p1066; + selp.b32 %r9911, %r10534, 0, %p719; + // begin inline asm + cp.async.cg.shared.global [ %r9910 + 0 ], [ %rd1121 + 0 ], 0x10, %r9911; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s64 %rd738, %rd128, %rd794; + add.s64 %rd739, %rd128, %rd795; + add.s64 %rd740, %rd128, %rd796; + add.s64 %rd741, %rd128, %rd797; + add.s64 %rd742, %rd128, %rd798; + add.s64 %rd743, %rd128, %rd799; + add.s64 %rd744, %rd128, %rd800; + add.s64 %rd745, %rd128, %rd801; + add.s64 %rd746, %rd128, %rd802; + add.s64 %rd747, %rd128, %rd803; + add.s64 %rd748, %rd128, %rd804; + add.s64 %rd749, %rd128, %rd805; + add.s64 %rd750, %rd128, %rd806; + add.s64 %rd751, %rd128, %rd807; + add.s64 %rd752, %rd128, %rd808; + add.s64 %rd753, %rd128, %rd809; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + add.s32 %r10535, %r10375, %r10512; + add.s32 %r9912, %r10535, %r787; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9912 + 0 ], [ %rd738 + 0 ], 0x4, %r9913; + // end inline asm + add.s32 %r9914, %r9912, 4; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9914 + 0 ], [ %rd739 + 0 ], 0x4, %r9915; + // end inline asm + add.s32 %r9916, %r9912, 32; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9916 + 0 ], [ %rd740 + 0 ], 0x4, %r9917; + // end inline asm + add.s32 %r9918, %r9912, 36; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9918 + 0 ], [ %rd741 + 0 ], 0x4, %r9919; + // end inline asm + add.s32 %r9920, %r9912, 64; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9920 + 0 ], [ %rd742 + 0 ], 0x4, %r9921; + // end inline asm + add.s32 %r9922, %r9912, 68; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9922 + 0 ], [ %rd743 + 0 ], 0x4, %r9923; + // end inline asm + add.s32 %r9924, %r9912, 96; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9924 + 0 ], [ %rd744 + 0 ], 0x4, %r9925; + // end inline asm + add.s32 %r9926, %r9912, 100; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9926 + 0 ], [ %rd745 + 0 ], 0x4, %r9927; + // end inline asm + add.s32 %r9928, %r9912, 128; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9928 + 0 ], [ %rd746 + 0 ], 0x4, %r9929; + // end inline asm + add.s32 %r9930, %r9912, 132; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9930 + 0 ], [ %rd747 + 0 ], 0x4, %r9931; + // end inline asm + add.s32 %r9932, %r9912, 160; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9932 + 0 ], [ %rd748 + 0 ], 0x4, %r9933; + // end inline asm + add.s32 %r9934, %r9912, 164; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9934 + 0 ], [ %rd749 + 0 ], 0x4, %r9935; + // end inline asm + add.s32 %r9936, %r9912, 192; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9936 + 0 ], [ %rd750 + 0 ], 0x4, %r9937; + // end inline asm + add.s32 %r9938, %r9912, 196; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9938 + 0 ], [ %rd751 + 0 ], 0x4, %r9939; + // end inline asm + add.s32 %r9940, %r9912, 224; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9940 + 0 ], [ %rd752 + 0 ], 0x4, %r9941; + // end inline asm + add.s32 %r9942, %r9912, 228; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r9942 + 0 ], [ %rd753 + 0 ], 0x4, %r9943; + // end inline asm + cp.async.commit_group; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + setp.ne.b32 %p1067, %r984, %r1468; + mov.b32 %r14618, %r14653; + mov.b32 %r14619, %r14652; + mov.b32 %r14620, %r14651; + mov.b32 %r14621, %r14650; + mov.b32 %r14622, %r14649; + mov.b32 %r14623, %r14648; + mov.b32 %r14624, %r14647; + mov.b32 %r14625, %r14646; + mov.b32 %r14626, %r14645; + mov.b32 %r14627, %r14644; + mov.b32 %r14628, %r14643; + mov.b32 %r14629, %r14642; + mov.b32 %r14630, %r14641; + mov.b32 %r14631, %r14640; + mov.b32 %r14632, %r14639; + mov.b32 %r14633, %r14638; + mov.b32 %r14638, %r1485; + mov.b32 %r14639, %r1484; + mov.b32 %r14640, %r1483; + mov.b32 %r14641, %r1482; + mov.b32 %r14642, %r1481; + mov.b32 %r14643, %r1480; + mov.b32 %r14644, %r1479; + mov.b32 %r14645, %r1478; + mov.b32 %r14646, %r1477; + mov.b32 %r14647, %r1476; + mov.b32 %r14648, %r1475; + mov.b32 %r14649, %r1474; + mov.b32 %r14650, %r1473; + mov.b32 %r14651, %r1472; + mov.b32 %r14652, %r1471; + mov.b32 %r14653, %r1470; + mov.b32 %r14790, %r1468; + @%p1067 bra $L__BB0_11; +$L__BB0_12: // %._crit_edge1701 + // in Loop: Header=BB0_9 Depth=1 + .loc 1 0 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:28 + setp.lt.s32 %p1132, %r760, 1; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:298:16 ] + // begin inline asm + // wait for regs: %r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725,%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789 + wgmma.wait_group.sync.aligned 0; + // end inline asm + cp.async.wait_group 0; + bar.sync 0; +$L__tmp16: + .loc 1 583 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd891, %rd125, %rd890; + add.s64 %rd893, %rd125, %rd892; + add.s64 %rd895, %rd125, %rd894; + add.s64 %rd897, %rd125, %rd896; + .loc 1 583 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:583:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd810, %rd891, %rd643; + add.s64 %rd811, %rd893, %rd643; + add.s64 %rd812, %rd895, %rd643; + add.s64 %rd813, %rd897, %rd643; + .loc 1 584 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd900, %rd126, %rd899; + add.s64 %rd902, %rd126, %rd901; + add.s64 %rd904, %rd126, %rd903; + add.s64 %rd906, %rd126, %rd905; + .loc 1 584 51 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:584:51 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd830, %rd900, %rd643; + add.s64 %rd831, %rd902, %rd643; + add.s64 %rd832, %rd904, %rd643; + add.s64 %rd833, %rd906, %rd643; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + cp.async.cg.shared.global [ %r10792 + 0 ], [ %rd810 + 0 ], 0x10, %r10793; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10794 + 0 ], [ %rd811 + 0 ], 0x10, %r10795; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10796 + 0 ], [ %rd812 + 0 ], 0x10, %r10797; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10798 + 0 ], [ %rd813 + 0 ], 0x10, %r10799; + // end inline asm + cp.async.commit_group; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd814, %rd127, %rd907; + cvt.s64.s32 %rd908, %r742; + cvt.u64.u32 %rd909, %r692; + add.s64 %rd910, %rd908, %rd909; + shl.b64 %rd911, %rd910, 2; + add.s64 %rd912, %rd127, %rd911; + add.s64 %rd815, %rd912, 4; + add.s64 %rd816, %rd912, 32; + add.s64 %rd817, %rd912, 36; + add.s64 %rd818, %rd127, %rd913; + add.s64 %rd819, %rd127, %rd914; + add.s64 %rd820, %rd127, %rd915; + add.s64 %rd821, %rd127, %rd916; + add.s64 %rd822, %rd127, %rd917; + add.s64 %rd823, %rd127, %rd918; + add.s64 %rd824, %rd127, %rd919; + add.s64 %rd825, %rd127, %rd920; + add.s64 %rd826, %rd127, %rd921; + add.s64 %rd827, %rd127, %rd922; + add.s64 %rd828, %rd127, %rd923; + add.s64 %rd829, %rd127, %rd924; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10800 + 0 ], [ %rd814 + 0 ], 0x4, %r10801; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10802 + 0 ], [ %rd815 + 0 ], 0x4, %r10803; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10804 + 0 ], [ %rd816 + 0 ], 0x4, %r10805; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10806 + 0 ], [ %rd817 + 0 ], 0x4, %r10807; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10808 + 0 ], [ %rd818 + 0 ], 0x4, %r10809; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10810 + 0 ], [ %rd819 + 0 ], 0x4, %r10811; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10812 + 0 ], [ %rd820 + 0 ], 0x4, %r10813; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10814 + 0 ], [ %rd821 + 0 ], 0x4, %r10815; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10816 + 0 ], [ %rd822 + 0 ], 0x4, %r10817; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10818 + 0 ], [ %rd823 + 0 ], 0x4, %r10819; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10820 + 0 ], [ %rd824 + 0 ], 0x4, %r10821; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10822 + 0 ], [ %rd825 + 0 ], 0x4, %r10823; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10824 + 0 ], [ %rd826 + 0 ], 0x4, %r10825; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10826 + 0 ], [ %rd827 + 0 ], 0x4, %r10827; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10828 + 0 ], [ %rd828 + 0 ], 0x4, %r10829; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10830 + 0 ], [ %rd829 + 0 ], 0x4, %r10831; + // end inline asm + cp.async.commit_group; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + cp.async.cg.shared.global [ %r10832 + 0 ], [ %rd830 + 0 ], 0x10, %r10793; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10834 + 0 ], [ %rd831 + 0 ], 0x10, %r10795; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10836 + 0 ], [ %rd832 + 0 ], 0x10, %r10797; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10838 + 0 ], [ %rd833 + 0 ], 0x10, %r10799; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd834, %rd128, %rd907; + add.s64 %rd925, %rd128, %rd911; + add.s64 %rd835, %rd925, 4; + add.s64 %rd836, %rd925, 32; + add.s64 %rd837, %rd925, 36; + add.s64 %rd838, %rd128, %rd913; + add.s64 %rd839, %rd128, %rd914; + add.s64 %rd840, %rd128, %rd915; + add.s64 %rd841, %rd128, %rd916; + add.s64 %rd842, %rd128, %rd917; + add.s64 %rd843, %rd128, %rd918; + add.s64 %rd844, %rd128, %rd919; + add.s64 %rd845, %rd128, %rd920; + add.s64 %rd846, %rd128, %rd921; + add.s64 %rd847, %rd128, %rd922; + add.s64 %rd848, %rd128, %rd923; + add.s64 %rd849, %rd128, %rd924; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10840 + 0 ], [ %rd834 + 0 ], 0x4, %r10801; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10842 + 0 ], [ %rd835 + 0 ], 0x4, %r10803; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10844 + 0 ], [ %rd836 + 0 ], 0x4, %r10805; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10846 + 0 ], [ %rd837 + 0 ], 0x4, %r10807; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10848 + 0 ], [ %rd838 + 0 ], 0x4, %r10809; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10850 + 0 ], [ %rd839 + 0 ], 0x4, %r10811; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10852 + 0 ], [ %rd840 + 0 ], 0x4, %r10813; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10854 + 0 ], [ %rd841 + 0 ], 0x4, %r10815; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10856 + 0 ], [ %rd842 + 0 ], 0x4, %r10817; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10858 + 0 ], [ %rd843 + 0 ], 0x4, %r10819; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10860 + 0 ], [ %rd844 + 0 ], 0x4, %r10821; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10862 + 0 ], [ %rd845 + 0 ], 0x4, %r10823; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10864 + 0 ], [ %rd846 + 0 ], 0x4, %r10825; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10866 + 0 ], [ %rd847 + 0 ], 0x4, %r10827; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10868 + 0 ], [ %rd848 + 0 ], 0x4, %r10829; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10870 + 0 ], [ %rd849 + 0 ], 0x4, %r10831; + // end inline asm + cp.async.commit_group; + .loc 1 608 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd1136, %rd810, 524288; + add.s64 %rd1135, %rd811, 524288; + add.s64 %rd1134, %rd812, 524288; + add.s64 %rd1133, %rd813, 524288; + .loc 1 609 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd1132, %rd830, 16384; + add.s64 %rd1131, %rd831, 16384; + add.s64 %rd1130, %rd832, 16384; + add.s64 %rd1129, %rd833, 16384; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + bar.sync 0; + // begin inline asm + cp.async.cg.shared.global [ %r10872 + 0 ], [ %rd1136 + 0 ], 0x10, %r10873; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10874 + 0 ], [ %rd1135 + 0 ], 0x10, %r10875; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10876 + 0 ], [ %rd1134 + 0 ], 0x10, %r10877; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10878 + 0 ], [ %rd1133 + 0 ], 0x10, %r10879; + // end inline asm + cp.async.commit_group; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd854, %rd814, 256; + add.s64 %rd926, %rd911, 256; + add.s64 %rd927, %rd127, %rd926; + add.s64 %rd855, %rd927, 4; + add.s64 %rd856, %rd927, 32; + add.s64 %rd857, %rd927, 36; + add.s64 %rd858, %rd818, 256; + add.s64 %rd859, %rd819, 256; + add.s64 %rd860, %rd820, 256; + add.s64 %rd861, %rd821, 256; + add.s64 %rd862, %rd822, 256; + add.s64 %rd863, %rd823, 256; + add.s64 %rd864, %rd824, 256; + add.s64 %rd865, %rd825, 256; + add.s64 %rd866, %rd826, 256; + add.s64 %rd867, %rd827, 256; + add.s64 %rd868, %rd828, 256; + add.s64 %rd869, %rd829, 256; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10880 + 0 ], [ %rd854 + 0 ], 0x4, %r10881; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10882 + 0 ], [ %rd855 + 0 ], 0x4, %r10883; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10884 + 0 ], [ %rd856 + 0 ], 0x4, %r10885; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10886 + 0 ], [ %rd857 + 0 ], 0x4, %r10887; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10888 + 0 ], [ %rd858 + 0 ], 0x4, %r10889; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10890 + 0 ], [ %rd859 + 0 ], 0x4, %r10891; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10892 + 0 ], [ %rd860 + 0 ], 0x4, %r10893; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10894 + 0 ], [ %rd861 + 0 ], 0x4, %r10895; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10896 + 0 ], [ %rd862 + 0 ], 0x4, %r10897; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10898 + 0 ], [ %rd863 + 0 ], 0x4, %r10899; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10900 + 0 ], [ %rd864 + 0 ], 0x4, %r10901; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10902 + 0 ], [ %rd865 + 0 ], 0x4, %r10903; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10904 + 0 ], [ %rd866 + 0 ], 0x4, %r10905; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10906 + 0 ], [ %rd867 + 0 ], 0x4, %r10907; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10908 + 0 ], [ %rd868 + 0 ], 0x4, %r10909; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10910 + 0 ], [ %rd869 + 0 ], 0x4, %r10911; + // end inline asm + cp.async.commit_group; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + cp.async.cg.shared.global [ %r10912 + 0 ], [ %rd1132 + 0 ], 0x10, %r10873; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10914 + 0 ], [ %rd1131 + 0 ], 0x10, %r10875; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10916 + 0 ], [ %rd1130 + 0 ], 0x10, %r10877; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r10918 + 0 ], [ %rd1129 + 0 ], 0x10, %r10879; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd874, %rd834, 256; + add.s64 %rd928, %rd128, %rd926; + add.s64 %rd875, %rd928, 4; + add.s64 %rd876, %rd928, 32; + add.s64 %rd877, %rd928, 36; + add.s64 %rd878, %rd838, 256; + add.s64 %rd879, %rd839, 256; + add.s64 %rd880, %rd840, 256; + add.s64 %rd881, %rd841, 256; + add.s64 %rd882, %rd842, 256; + add.s64 %rd883, %rd843, 256; + add.s64 %rd884, %rd844, 256; + add.s64 %rd885, %rd845, 256; + add.s64 %rd886, %rd846, 256; + add.s64 %rd887, %rd847, 256; + add.s64 %rd888, %rd848, 256; + add.s64 %rd889, %rd849, 256; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10920 + 0 ], [ %rd874 + 0 ], 0x4, %r10881; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10922 + 0 ], [ %rd875 + 0 ], 0x4, %r10883; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10924 + 0 ], [ %rd876 + 0 ], 0x4, %r10885; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10926 + 0 ], [ %rd877 + 0 ], 0x4, %r10887; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10928 + 0 ], [ %rd878 + 0 ], 0x4, %r10889; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10930 + 0 ], [ %rd879 + 0 ], 0x4, %r10891; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10932 + 0 ], [ %rd880 + 0 ], 0x4, %r10893; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10934 + 0 ], [ %rd881 + 0 ], 0x4, %r10895; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10936 + 0 ], [ %rd882 + 0 ], 0x4, %r10897; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10938 + 0 ], [ %rd883 + 0 ], 0x4, %r10899; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10940 + 0 ], [ %rd884 + 0 ], 0x4, %r10901; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10942 + 0 ], [ %rd885 + 0 ], 0x4, %r10903; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10944 + 0 ], [ %rd886 + 0 ], 0x4, %r10905; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10946 + 0 ], [ %rd887 + 0 ], 0x4, %r10907; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10948 + 0 ], [ %rd888 + 0 ], 0x4, %r10909; + // end inline asm + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r10950 + 0 ], [ %rd889 + 0 ], 0x4, %r10911; + // end inline asm + cp.async.commit_group; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + @%p1132 bra $L__BB0_15; +// %bb.13: // %.lr.ph1873.preheader + // in Loop: Header=BB0_9 Depth=1 + .loc 1 0 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:28 + mov.b32 %r11506, 0; + mov.b32 %r14952, 1; + mov.b32 %r14951, -1; + mov.b32 %r14935, %r759; + mov.b32 %r14936, %r758; + mov.b32 %r14937, %r757; + mov.b32 %r14938, %r756; + mov.b32 %r14939, %r755; + mov.b32 %r14940, %r754; + mov.b32 %r14941, %r753; + mov.b32 %r14942, %r752; + mov.b32 %r14943, %r751; + mov.b32 %r14944, %r750; + mov.b32 %r14945, %r749; + mov.b32 %r14946, %r748; + mov.b32 %r14947, %r747; + mov.b32 %r14948, %r746; + mov.b32 %r14949, %r745; + mov.b32 %r14950, %r744; + mov.b32 %r14953, %r14951; + mov.b32 %r14954, %r14952; + mov.b32 %r14955, %r957; + mov.b32 %r14956, %r956; + mov.b32 %r14957, %r955; + mov.b32 %r14958, %r954; + mov.b32 %r14959, %r953; + mov.b32 %r14960, %r952; + mov.b32 %r14961, %r951; + mov.b32 %r14962, %r950; + mov.b32 %r14963, %r949; + mov.b32 %r14964, %r948; + mov.b32 %r14965, %r947; + mov.b32 %r14966, %r946; + mov.b32 %r14967, %r945; + mov.b32 %r14968, %r944; + mov.b32 %r14969, %r943; + mov.b32 %r14970, %r942; + mov.b32 %r14971, %r958; + mov.b32 %r14972, %r959; + mov.b32 %r14973, %r960; + mov.b32 %r14974, %r961; + mov.b32 %r14975, %r958; + mov.b32 %r14976, %r959; + mov.b32 %r14977, %r960; + mov.b32 %r14978, %r961; + mov.b32 %r15107, %r11506; +$L__BB0_14: // %.lr.ph1873 + // Parent Loop BB0_9 Depth=1 + // => This Inner Loop Header: Depth=2 + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1189, %r15107, %r982; + setp.lt.s32 %p1155, %r15107, %r983; + add.s32 %r13205, %r14951, 1; + setp.gt.s32 %p1190, %r13205, 1; + selp.b32 %r14951, 0, %r13205, %p1190; + add.s32 %r13206, %r14953, 1; + setp.gt.s32 %p1191, %r13206, 2; + selp.b32 %r14953, 0, %r13206, %p1191; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1192, %r14950, %r2338; + setp.lt.s32 %p1193, %r14949, %r2338; + setp.lt.s32 %p1194, %r14948, %r2338; + setp.lt.s32 %p1195, %r14947, %r2338; + setp.lt.s32 %p1196, %r14946, %r2338; + setp.lt.s32 %p1197, %r14945, %r2338; + setp.lt.s32 %p1198, %r14944, %r2338; + setp.lt.s32 %p1199, %r14943, %r2338; + setp.lt.s32 %p1200, %r14942, %r2338; + setp.lt.s32 %p1201, %r14941, %r2338; + setp.lt.s32 %p1202, %r14940, %r2338; + setp.lt.s32 %p1203, %r14939, %r2338; + setp.lt.s32 %p1204, %r14938, %r2338; + setp.lt.s32 %p1205, %r14937, %r2338; + setp.lt.s32 %p1206, %r14936, %r2338; + setp.lt.s32 %p1207, %r14935, %r2338; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cp.async.wait_group 4; + bar.sync 0; + shl.b32 %r13207, %r14953, 14; + add.s32 %r11508, %r7390, %r13207; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13209, %r14951, 8; + add.s32 %r13210, %r7390, 98304; + add.s32 %r13211, %r13210, %r13209; + add.s32 %r13212, %r13211, %r787; + ld.shared.v2.b32 {%r13213, %r13214}, [%r13212]; + ld.shared.v2.b32 {%r13215, %r13216}, [%r13212+32]; + ld.shared.v2.b32 {%r13217, %r13218}, [%r13212+64]; + ld.shared.v2.b32 {%r13219, %r13220}, [%r13212+96]; + ld.shared.v2.b32 {%r13221, %r13222}, [%r13212+128]; + ld.shared.v2.b32 {%r13223, %r13224}, [%r13212+160]; + ld.shared.v2.b32 {%r13225, %r13226}, [%r13212+192]; + ld.shared.v2.b32 {%r13227, %r13228}, [%r13212+224]; + .loc 1 657 26 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:657:26 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.eq.f32 %p1208, %r13213, 0fFF800000; + setp.eq.f32 %p1209, %r13214, 0fFF800000; + setp.eq.f32 %p1210, %r13215, 0fFF800000; + setp.eq.f32 %p1211, %r13216, 0fFF800000; + setp.eq.f32 %p1212, %r13217, 0fFF800000; + setp.eq.f32 %p1213, %r13218, 0fFF800000; + setp.eq.f32 %p1214, %r13219, 0fFF800000; + setp.eq.f32 %p1215, %r13220, 0fFF800000; + setp.eq.f32 %p1216, %r13221, 0fFF800000; + setp.eq.f32 %p1217, %r13222, 0fFF800000; + setp.eq.f32 %p1218, %r13223, 0fFF800000; + setp.eq.f32 %p1219, %r13224, 0fFF800000; + setp.eq.f32 %p1220, %r13225, 0fFF800000; + setp.eq.f32 %p1221, %r13226, 0fFF800000; + setp.eq.f32 %p1222, %r13227, 0fFF800000; + setp.eq.f32 %p1223, %r13228, 0fFF800000; + .loc 1 657 46 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:657:46 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13229, 0f00000000, %r13213, %p1208; + selp.f32 %r13230, 0f00000000, %r13214, %p1209; + selp.f32 %r13231, 0f00000000, %r13215, %p1210; + selp.f32 %r13232, 0f00000000, %r13216, %p1211; + selp.f32 %r13233, 0f00000000, %r13217, %p1212; + selp.f32 %r13234, 0f00000000, %r13218, %p1213; + selp.f32 %r13235, 0f00000000, %r13219, %p1214; + selp.f32 %r13236, 0f00000000, %r13220, %p1215; + selp.f32 %r13237, 0f00000000, %r13221, %p1216; + selp.f32 %r13238, 0f00000000, %r13222, %p1217; + selp.f32 %r13239, 0f00000000, %r13223, %p1218; + selp.f32 %r13240, 0f00000000, %r13224, %p1219; + selp.f32 %r13241, 0f00000000, %r13225, %p1220; + selp.f32 %r13242, 0f00000000, %r13226, %p1221; + selp.f32 %r13243, 0f00000000, %r13227, %p1222; + selp.f32 %r13244, 0f00000000, %r13228, %p1223; + .loc 1 658 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:658:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shfl.sync.idx.b32 %r13245, %r10, 0, 31, -1; + wgmma.fence.sync.aligned; + shl.b32 %r13246, %r13245, 11; + and.b32 %r13247, %r13246, 8192; + add.s32 %r11467, %r7390, 99328; + add.s32 %r13248, %r13247, %r11467; + bfe.u32 %r13249, %r13248, 4, 14; + cvt.u64.u32 %rd1015, %r13249; + or.b64 %rd929, %rd1015, 4611686293372403712; + bfe.u32 %r13250, %r11508, 4, 14; + cvt.u64.u32 %rd1016, %r13250; + or.b64 %rd930, %rd1016, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd929, %rd930, 0, 1, 1, 0, 0; + // end inline asm + or.b32 %r13251, %r13247, 32; + add.s32 %r13252, %r13251, %r11467; + bfe.u32 %r13253, %r13252, 4, 14; + cvt.u64.u32 %rd1017, %r13253; + or.b64 %rd931, %rd1017, 4611686293372403712; + add.s32 %r13254, %r11508, 32; + bfe.u32 %r13255, %r13254, 4, 14; + cvt.u64.u32 %rd1018, %r13255; + or.b64 %rd932, %rd1018, 4611686293338849280; + mov.pred %p1133, -1; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd931, %rd932, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13256, %r13247, 64; + add.s32 %r13257, %r13256, %r11467; + bfe.u32 %r13258, %r13257, 4, 14; + cvt.u64.u32 %rd1019, %r13258; + or.b64 %rd933, %rd1019, 4611686293372403712; + add.s32 %r13259, %r11508, 64; + bfe.u32 %r13260, %r13259, 4, 14; + cvt.u64.u32 %rd1020, %r13260; + or.b64 %rd934, %rd1020, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd933, %rd934, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13261, %r13247, 96; + add.s32 %r13262, %r13261, %r11467; + bfe.u32 %r13263, %r13262, 4, 14; + cvt.u64.u32 %rd1021, %r13263; + or.b64 %rd935, %rd1021, 4611686293372403712; + add.s32 %r13264, %r11508, 96; + bfe.u32 %r13265, %r13264, 4, 14; + cvt.u64.u32 %rd1022, %r13265; + or.b64 %rd936, %rd1022, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd935, %rd936, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13266, %r13247, 16384; + add.s32 %r13267, %r13266, %r11467; + bfe.u32 %r13268, %r13267, 4, 14; + cvt.u64.u32 %rd1023, %r13268; + or.b64 %rd937, %rd1023, 4611686293372403712; + add.s32 %r13269, %r11508, 8192; + bfe.u32 %r13270, %r13269, 4, 14; + cvt.u64.u32 %rd1024, %r13270; + or.b64 %rd938, %rd1024, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd937, %rd938, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13271, %r13247, 16416; + add.s32 %r13272, %r13271, %r11467; + bfe.u32 %r13273, %r13272, 4, 14; + cvt.u64.u32 %rd1025, %r13273; + or.b64 %rd939, %rd1025, 4611686293372403712; + add.s32 %r13274, %r11508, 8224; + bfe.u32 %r13275, %r13274, 4, 14; + cvt.u64.u32 %rd1026, %r13275; + or.b64 %rd940, %rd1026, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd939, %rd940, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13276, %r13247, 16448; + add.s32 %r13277, %r13276, %r11467; + bfe.u32 %r13278, %r13277, 4, 14; + cvt.u64.u32 %rd1027, %r13278; + or.b64 %rd941, %rd1027, 4611686293372403712; + add.s32 %r13279, %r11508, 8256; + bfe.u32 %r13280, %r13279, 4, 14; + cvt.u64.u32 %rd1028, %r13280; + or.b64 %rd942, %rd1028, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd941, %rd942, %p1133, 1, 1, 0, 0; + // end inline asm + or.b32 %r13281, %r13247, 16480; + add.s32 %r13282, %r13281, %r11467; + bfe.u32 %r13283, %r13282, 4, 14; + cvt.u64.u32 %rd1029, %r13283; + or.b64 %rd943, %rd1029, 4611686293372403712; + add.s32 %r13284, %r11508, 8288; + bfe.u32 %r13285, %r13284, 4, 14; + cvt.u64.u32 %rd1030, %r13285; + or.b64 %rd944, %rd1030, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082}, %rd943, %rd944, %p1133, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r11468, %r11506; + mov.b32 %r11469, %r11506; + mov.b32 %r11471, %r11506; + mov.b32 %r11472, %r11506; + mov.b32 %r11470, %r11508; + // begin inline asm + // wait for regs: %r11051,%r11052,%r11053,%r11054,%r11055,%r11056,%r11057,%r11058,%r11059,%r11060,%r11061,%r11062,%r11063,%r11064,%r11065,%r11066,%r11067,%r11068,%r11069,%r11070,%r11071,%r11072,%r11073,%r11074,%r11075,%r11076,%r11077,%r11078,%r11079,%r11080,%r11081,%r11082,%r11467,%r11468,%r11469,%r11470,%r11471,%r11472 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 660 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:660:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13286, %r11051, 0f3DB504F3; + mul.f32 %r13287, %r11052, 0f3DB504F3; + mul.f32 %r13288, %r11053, 0f3DB504F3; + mul.f32 %r13289, %r11054, 0f3DB504F3; + mul.f32 %r13290, %r11055, 0f3DB504F3; + mul.f32 %r13291, %r11056, 0f3DB504F3; + mul.f32 %r13292, %r11057, 0f3DB504F3; + mul.f32 %r13293, %r11058, 0f3DB504F3; + mul.f32 %r13294, %r11059, 0f3DB504F3; + mul.f32 %r13295, %r11060, 0f3DB504F3; + mul.f32 %r13296, %r11061, 0f3DB504F3; + mul.f32 %r13297, %r11062, 0f3DB504F3; + mul.f32 %r13298, %r11063, 0f3DB504F3; + mul.f32 %r13299, %r11064, 0f3DB504F3; + mul.f32 %r13300, %r11065, 0f3DB504F3; + mul.f32 %r13301, %r11066, 0f3DB504F3; + mul.f32 %r13302, %r11067, 0f3DB504F3; + mul.f32 %r13303, %r11068, 0f3DB504F3; + mul.f32 %r13304, %r11069, 0f3DB504F3; + mul.f32 %r13305, %r11070, 0f3DB504F3; + mul.f32 %r13306, %r11071, 0f3DB504F3; + mul.f32 %r13307, %r11072, 0f3DB504F3; + mul.f32 %r13308, %r11073, 0f3DB504F3; + mul.f32 %r13309, %r11074, 0f3DB504F3; + mul.f32 %r13310, %r11075, 0f3DB504F3; + mul.f32 %r13311, %r11076, 0f3DB504F3; + mul.f32 %r13312, %r11077, 0f3DB504F3; + mul.f32 %r13313, %r11078, 0f3DB504F3; + mul.f32 %r13314, %r11079, 0f3DB504F3; + mul.f32 %r13315, %r11080, 0f3DB504F3; + mul.f32 %r13316, %r11081, 0f3DB504F3; + mul.f32 %r13317, %r11082, 0f3DB504F3; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13318, %r13286, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13319, %r13318, 0fFF800000, %p1192; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13320, %r13287, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13321, %r13320, 0fFF800000, %p1193; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13322, %r13288, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13323, %r13322, 0fFF800000, %p1192; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13324, %r13289, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13325, %r13324, 0fFF800000, %p1193; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13326, %r13290, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13327, %r13326, 0fFF800000, %p1194; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13328, %r13291, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13329, %r13328, 0fFF800000, %p1195; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13330, %r13292, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13331, %r13330, 0fFF800000, %p1194; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13332, %r13293, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13333, %r13332, 0fFF800000, %p1195; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13334, %r13294, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13335, %r13334, 0fFF800000, %p1196; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13336, %r13295, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13337, %r13336, 0fFF800000, %p1197; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13338, %r13296, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13339, %r13338, 0fFF800000, %p1196; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13340, %r13297, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13341, %r13340, 0fFF800000, %p1197; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13342, %r13298, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13343, %r13342, 0fFF800000, %p1198; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13344, %r13299, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13345, %r13344, 0fFF800000, %p1199; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13346, %r13300, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13347, %r13346, 0fFF800000, %p1198; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13348, %r13301, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13349, %r13348, 0fFF800000, %p1199; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13350, %r13302, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13351, %r13350, 0fFF800000, %p1200; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13352, %r13303, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13353, %r13352, 0fFF800000, %p1201; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13354, %r13304, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13355, %r13354, 0fFF800000, %p1200; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13356, %r13305, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13357, %r13356, 0fFF800000, %p1201; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13358, %r13306, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13359, %r13358, 0fFF800000, %p1202; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13360, %r13307, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13361, %r13360, 0fFF800000, %p1203; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13362, %r13308, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13363, %r13362, 0fFF800000, %p1202; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13364, %r13309, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13365, %r13364, 0fFF800000, %p1203; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13366, %r13310, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13367, %r13366, 0fFF800000, %p1204; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13368, %r13311, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13369, %r13368, 0fFF800000, %p1205; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13370, %r13312, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13371, %r13370, 0fFF800000, %p1204; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13372, %r13313, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13373, %r13372, 0fFF800000, %p1205; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13374, %r13314, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13375, %r13374, 0fFF800000, %p1206; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13376, %r13315, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13377, %r13376, 0fFF800000, %p1207; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13378, %r13316, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13379, %r13378, 0fFF800000, %p1206; + .loc 1 703 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:703:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13380, %r13317, 0f3FB8AA3B; + .loc 1 674 78 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:674:78 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.f32 %r13381, %r13380, 0fFF800000, %p1207; + .loc 1 704 40 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:704:40 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + sub.f32 %r13382, %r13319, %r13229; + sub.f32 %r13383, %r13321, %r13230; + sub.f32 %r13384, %r13323, %r13229; + sub.f32 %r13385, %r13325, %r13230; + sub.f32 %r13386, %r13327, %r13231; + sub.f32 %r13387, %r13329, %r13232; + sub.f32 %r13388, %r13331, %r13231; + sub.f32 %r13389, %r13333, %r13232; + sub.f32 %r13390, %r13335, %r13233; + sub.f32 %r13391, %r13337, %r13234; + sub.f32 %r13392, %r13339, %r13233; + sub.f32 %r13393, %r13341, %r13234; + sub.f32 %r13394, %r13343, %r13235; + sub.f32 %r13395, %r13345, %r13236; + sub.f32 %r13396, %r13347, %r13235; + sub.f32 %r13397, %r13349, %r13236; + sub.f32 %r13398, %r13351, %r13237; + sub.f32 %r13399, %r13353, %r13238; + sub.f32 %r13400, %r13355, %r13237; + sub.f32 %r13401, %r13357, %r13238; + sub.f32 %r13402, %r13359, %r13239; + sub.f32 %r13403, %r13361, %r13240; + sub.f32 %r13404, %r13363, %r13239; + sub.f32 %r13405, %r13365, %r13240; + sub.f32 %r13406, %r13367, %r13241; + sub.f32 %r13407, %r13369, %r13242; + sub.f32 %r13408, %r13371, %r13241; + sub.f32 %r13409, %r13373, %r13242; + sub.f32 %r13410, %r13375, %r13243; + sub.f32 %r13411, %r13377, %r13244; + sub.f32 %r13412, %r13379, %r13243; + sub.f32 %r13413, %r13381, %r13244; + .loc 1 704 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:704:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + ex2.approx.ftz.f32 %r13414, %r13382; + ex2.approx.ftz.f32 %r13415, %r13383; + ex2.approx.ftz.f32 %r13416, %r13384; + ex2.approx.ftz.f32 %r13417, %r13385; + ex2.approx.ftz.f32 %r13418, %r13386; + ex2.approx.ftz.f32 %r13419, %r13387; + ex2.approx.ftz.f32 %r13420, %r13388; + ex2.approx.ftz.f32 %r13421, %r13389; + ex2.approx.ftz.f32 %r13422, %r13390; + ex2.approx.ftz.f32 %r13423, %r13391; + ex2.approx.ftz.f32 %r13424, %r13392; + ex2.approx.ftz.f32 %r13425, %r13393; + ex2.approx.ftz.f32 %r13426, %r13394; + ex2.approx.ftz.f32 %r13427, %r13395; + ex2.approx.ftz.f32 %r13428, %r13396; + ex2.approx.ftz.f32 %r13429, %r13397; + ex2.approx.ftz.f32 %r13430, %r13398; + ex2.approx.ftz.f32 %r13431, %r13399; + ex2.approx.ftz.f32 %r13432, %r13400; + ex2.approx.ftz.f32 %r13433, %r13401; + ex2.approx.ftz.f32 %r13434, %r13402; + ex2.approx.ftz.f32 %r13435, %r13403; + ex2.approx.ftz.f32 %r13436, %r13404; + ex2.approx.ftz.f32 %r13437, %r13405; + ex2.approx.ftz.f32 %r13438, %r13406; + ex2.approx.ftz.f32 %r13439, %r13407; + ex2.approx.ftz.f32 %r13440, %r13408; + ex2.approx.ftz.f32 %r13441, %r13409; + ex2.approx.ftz.f32 %r13442, %r13410; + ex2.approx.ftz.f32 %r13443, %r13411; + ex2.approx.ftz.f32 %r13444, %r13412; + ex2.approx.ftz.f32 %r13445, %r13413; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13446, %r7390, 49152; + add.s32 %r12554, %r13446, %r13207; + .loc 1 708 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:708:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16x2.f32 %r11639, %r13415, %r13414; + cvt.rn.bf16x2.f32 %r11640, %r13417, %r13416; + cvt.rn.bf16x2.f32 %r11641, %r13419, %r13418; + cvt.rn.bf16x2.f32 %r11642, %r13421, %r13420; + cvt.rn.bf16x2.f32 %r11771, %r13423, %r13422; + cvt.rn.bf16x2.f32 %r11772, %r13425, %r13424; + cvt.rn.bf16x2.f32 %r11773, %r13427, %r13426; + cvt.rn.bf16x2.f32 %r11774, %r13429, %r13428; + cvt.rn.bf16x2.f32 %r11903, %r13431, %r13430; + cvt.rn.bf16x2.f32 %r11904, %r13433, %r13432; + cvt.rn.bf16x2.f32 %r11905, %r13435, %r13434; + cvt.rn.bf16x2.f32 %r11906, %r13437, %r13436; + cvt.rn.bf16x2.f32 %r12035, %r13439, %r13438; + cvt.rn.bf16x2.f32 %r12036, %r13441, %r13440; + cvt.rn.bf16x2.f32 %r12037, %r13443, %r13442; + cvt.rn.bf16x2.f32 %r12038, %r13445, %r13444; + .loc 1 708 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:708:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + wgmma.fence.sync.aligned; + bfe.u32 %r13447, %r12554, 4, 14; + cvt.u64.u32 %rd1031, %r13447; + or.b64 %rd945, %rd1031, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r11639,%r11640,%r11641,%r11642}, %rd945, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13448, %r12554, 2048; + bfe.u32 %r13449, %r13448, 4, 14; + cvt.u64.u32 %rd1032, %r13449; + or.b64 %rd946, %rd1032, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r11771,%r11772,%r11773,%r11774}, %rd946, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13450, %r12554, 4096; + bfe.u32 %r13451, %r13450, 4, 14; + cvt.u64.u32 %rd1033, %r13451; + or.b64 %rd947, %rd1033, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r11903,%r11904,%r11905,%r11906}, %rd947, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13452, %r12554, 6144; + bfe.u32 %r13453, %r13452, 4, 14; + cvt.u64.u32 %rd1034, %r13453; + or.b64 %rd948, %rd1034, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14662,%r14663,%r14664,%r14665,%r14666,%r14667,%r14668,%r14669,%r14670,%r14671,%r14672,%r14673,%r14674,%r14675,%r14676,%r14677,%r14678,%r14679,%r14680,%r14681,%r14682,%r14683,%r14684,%r14685,%r14686,%r14687,%r14688,%r14689,%r14690,%r14691,%r14692,%r14693,%r14694,%r14695,%r14696,%r14697,%r14698,%r14699,%r14700,%r14701,%r14702,%r14703,%r14704,%r14705,%r14706,%r14707,%r14708,%r14709,%r14710,%r14711,%r14712,%r14713,%r14714,%r14715,%r14716,%r14717,%r14718,%r14719,%r14720,%r14721,%r14722,%r14723,%r14724,%r14725}, {%r12035,%r12036,%r12037,%r12038}, %rd948, %p1133, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13454, %r7390, 98816; + add.s32 %r13455, %r13454, %r13209; + add.s32 %r13456, %r13455, %r787; + ld.shared.v2.b32 {%r13457, %r13458}, [%r13456]; + ld.shared.v2.b32 {%r13459, %r13460}, [%r13456+32]; + ld.shared.v2.b32 {%r13461, %r13462}, [%r13456+64]; + ld.shared.v2.b32 {%r13463, %r13464}, [%r13456+96]; + ld.shared.v2.b32 {%r13465, %r13466}, [%r13456+128]; + ld.shared.v2.b32 {%r13467, %r13468}, [%r13456+160]; + ld.shared.v2.b32 {%r13469, %r13470}, [%r13456+192]; + ld.shared.v2.b32 {%r13471, %r13472}, [%r13456+224]; + .loc 1 714 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:714:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + wgmma.fence.sync.aligned; + add.s32 %r12551, %r7390, 132096; + add.s32 %r13473, %r13247, %r12551; + bfe.u32 %r13474, %r13473, 4, 14; + cvt.u64.u32 %rd1035, %r13474; + or.b64 %rd949, %rd1035, 4611686293372403712; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd949, %rd945, 0, 1, 1, 0, 0; + // end inline asm + add.s32 %r13475, %r13251, %r12551; + bfe.u32 %r13476, %r13475, 4, 14; + cvt.u64.u32 %rd1036, %r13476; + or.b64 %rd951, %rd1036, 4611686293372403712; + add.s32 %r13477, %r12554, 32; + bfe.u32 %r13478, %r13477, 4, 14; + cvt.u64.u32 %rd1037, %r13478; + or.b64 %rd952, %rd1037, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd951, %rd952, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13479, %r13256, %r12551; + bfe.u32 %r13480, %r13479, 4, 14; + cvt.u64.u32 %rd1038, %r13480; + or.b64 %rd953, %rd1038, 4611686293372403712; + add.s32 %r13481, %r12554, 64; + bfe.u32 %r13482, %r13481, 4, 14; + cvt.u64.u32 %rd1039, %r13482; + or.b64 %rd954, %rd1039, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd953, %rd954, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13483, %r13261, %r12551; + bfe.u32 %r13484, %r13483, 4, 14; + cvt.u64.u32 %rd1040, %r13484; + or.b64 %rd955, %rd1040, 4611686293372403712; + add.s32 %r13485, %r12554, 96; + bfe.u32 %r13486, %r13485, 4, 14; + cvt.u64.u32 %rd1041, %r13486; + or.b64 %rd956, %rd1041, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd955, %rd956, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13487, %r13266, %r12551; + bfe.u32 %r13488, %r13487, 4, 14; + cvt.u64.u32 %rd1042, %r13488; + or.b64 %rd957, %rd1042, 4611686293372403712; + add.s32 %r13489, %r12554, 8192; + bfe.u32 %r13490, %r13489, 4, 14; + cvt.u64.u32 %rd1043, %r13490; + or.b64 %rd958, %rd1043, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd957, %rd958, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13491, %r13271, %r12551; + bfe.u32 %r13492, %r13491, 4, 14; + cvt.u64.u32 %rd1044, %r13492; + or.b64 %rd959, %rd1044, 4611686293372403712; + add.s32 %r13493, %r12554, 8224; + bfe.u32 %r13494, %r13493, 4, 14; + cvt.u64.u32 %rd1045, %r13494; + or.b64 %rd960, %rd1045, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd959, %rd960, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13495, %r13276, %r12551; + bfe.u32 %r13496, %r13495, 4, 14; + cvt.u64.u32 %rd1046, %r13496; + or.b64 %rd961, %rd1046, 4611686293372403712; + add.s32 %r13497, %r12554, 8256; + bfe.u32 %r13498, %r13497, 4, 14; + cvt.u64.u32 %rd1047, %r13498; + or.b64 %rd962, %rd1047, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd961, %rd962, %p1133, 1, 1, 0, 0; + // end inline asm + add.s32 %r13499, %r13281, %r12551; + bfe.u32 %r13500, %r13499, 4, 14; + cvt.u64.u32 %rd1048, %r13500; + or.b64 %rd963, %rd1048, 4611686293372403712; + add.s32 %r13501, %r12554, 8288; + bfe.u32 %r13502, %r13501, 4, 14; + cvt.u64.u32 %rd1049, %r13502; + or.b64 %rd964, %rd1049, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166}, %rd963, %rd964, %p1133, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r12556, %r11506; + mov.b32 %r12552, %r11506; + mov.b32 %r12553, %r11506; + mov.b32 %r12555, %r11506; + // begin inline asm + // wait for regs: %r12135,%r12136,%r12137,%r12138,%r12139,%r12140,%r12141,%r12142,%r12143,%r12144,%r12145,%r12146,%r12147,%r12148,%r12149,%r12150,%r12151,%r12152,%r12153,%r12154,%r12155,%r12156,%r12157,%r12158,%r12159,%r12160,%r12161,%r12162,%r12163,%r12164,%r12165,%r12166,%r12551,%r12552,%r12553,%r12554,%r12555,%r12556 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 715 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:715:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + sub.f32 %r13503, %r12135, %r13457; + sub.f32 %r13504, %r12136, %r13458; + sub.f32 %r13505, %r12137, %r13457; + sub.f32 %r13506, %r12138, %r13458; + sub.f32 %r13507, %r12139, %r13459; + sub.f32 %r13508, %r12140, %r13460; + sub.f32 %r13509, %r12141, %r13459; + sub.f32 %r13510, %r12142, %r13460; + sub.f32 %r13511, %r12143, %r13461; + sub.f32 %r13512, %r12144, %r13462; + sub.f32 %r13513, %r12145, %r13461; + sub.f32 %r13514, %r12146, %r13462; + sub.f32 %r13515, %r12147, %r13463; + sub.f32 %r13516, %r12148, %r13464; + sub.f32 %r13517, %r12149, %r13463; + sub.f32 %r13518, %r12150, %r13464; + sub.f32 %r13519, %r12151, %r13465; + sub.f32 %r13520, %r12152, %r13466; + sub.f32 %r13521, %r12153, %r13465; + sub.f32 %r13522, %r12154, %r13466; + sub.f32 %r13523, %r12155, %r13467; + sub.f32 %r13524, %r12156, %r13468; + sub.f32 %r13525, %r12157, %r13467; + sub.f32 %r13526, %r12158, %r13468; + sub.f32 %r13527, %r12159, %r13469; + sub.f32 %r13528, %r12160, %r13470; + sub.f32 %r13529, %r12161, %r13469; + sub.f32 %r13530, %r12162, %r13470; + sub.f32 %r13531, %r12163, %r13471; + sub.f32 %r13532, %r12164, %r13472; + sub.f32 %r13533, %r12165, %r13471; + sub.f32 %r13534, %r12166, %r13472; + .loc 1 715 16 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:715:16 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.f32 %r13535, %r13414, %r13503; + mul.f32 %r13536, %r13415, %r13504; + mul.f32 %r13537, %r13416, %r13505; + mul.f32 %r13538, %r13417, %r13506; + mul.f32 %r13539, %r13418, %r13507; + mul.f32 %r13540, %r13419, %r13508; + mul.f32 %r13541, %r13420, %r13509; + mul.f32 %r13542, %r13421, %r13510; + mul.f32 %r13543, %r13422, %r13511; + mul.f32 %r13544, %r13423, %r13512; + mul.f32 %r13545, %r13424, %r13513; + mul.f32 %r13546, %r13425, %r13514; + mul.f32 %r13547, %r13426, %r13515; + mul.f32 %r13548, %r13427, %r13516; + mul.f32 %r13549, %r13428, %r13517; + mul.f32 %r13550, %r13429, %r13518; + mul.f32 %r13551, %r13430, %r13519; + mul.f32 %r13552, %r13431, %r13520; + mul.f32 %r13553, %r13432, %r13521; + mul.f32 %r13554, %r13433, %r13522; + mul.f32 %r13555, %r13434, %r13523; + mul.f32 %r13556, %r13435, %r13524; + mul.f32 %r13557, %r13436, %r13525; + mul.f32 %r13558, %r13437, %r13526; + mul.f32 %r13559, %r13438, %r13527; + mul.f32 %r13560, %r13439, %r13528; + mul.f32 %r13561, %r13440, %r13529; + mul.f32 %r13562, %r13441, %r13530; + mul.f32 %r13563, %r13442, %r13531; + mul.f32 %r13564, %r13443, %r13532; + mul.f32 %r13565, %r13444, %r13533; + mul.f32 %r13566, %r13445, %r13534; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs307, %r13535; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs308, %rs307, 0x0000, %p1192; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs309, %r13536; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs310, %rs309, 0x0000, %p1193; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs311, %r13537; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs312, %rs311, 0x0000, %p1192; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs313, %r13538; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs314, %rs313, 0x0000, %p1193; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs315, %r13539; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs316, %rs315, 0x0000, %p1194; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs317, %r13540; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs318, %rs317, 0x0000, %p1195; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs319, %r13541; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs320, %rs319, 0x0000, %p1194; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs321, %r13542; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs322, %rs321, 0x0000, %p1195; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs323, %r13543; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs324, %rs323, 0x0000, %p1196; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs325, %r13544; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs326, %rs325, 0x0000, %p1197; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs327, %r13545; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs328, %rs327, 0x0000, %p1196; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs329, %r13546; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs330, %rs329, 0x0000, %p1197; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs331, %r13547; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs332, %rs331, 0x0000, %p1198; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs333, %r13548; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs334, %rs333, 0x0000, %p1199; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs335, %r13549; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs336, %rs335, 0x0000, %p1198; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs337, %r13550; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs338, %rs337, 0x0000, %p1199; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs339, %r13551; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs340, %rs339, 0x0000, %p1200; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs341, %r13552; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs342, %rs341, 0x0000, %p1201; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs343, %r13553; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs344, %rs343, 0x0000, %p1200; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs345, %r13554; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs346, %rs345, 0x0000, %p1201; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs347, %r13555; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs348, %rs347, 0x0000, %p1202; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs349, %r13556; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs350, %rs349, 0x0000, %p1203; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs351, %r13557; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs352, %rs351, 0x0000, %p1202; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs353, %r13558; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs354, %rs353, 0x0000, %p1203; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs355, %r13559; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs356, %rs355, 0x0000, %p1204; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs357, %r13560; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs358, %rs357, 0x0000, %p1205; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs359, %r13561; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs360, %rs359, 0x0000, %p1204; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs361, %r13562; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs362, %rs361, 0x0000, %p1205; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs363, %r13563; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs364, %rs363, 0x0000, %p1206; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs365, %r13564; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs366, %rs365, 0x0000, %p1207; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs367, %r13565; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs368, %rs367, 0x0000, %p1206; + .loc 1 739 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + cvt.rn.bf16.f32 %rs369, %r13566; + .loc 1 723 70 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:723:70 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + selp.b16 %rs370, %rs369, 0x0000, %p1207; + .loc 1 739 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:739:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mov.b32 %r12723, {%rs308, %rs310}; + mov.b32 %r12724, {%rs312, %rs314}; + mov.b32 %r12725, {%rs316, %rs318}; + mov.b32 %r12726, {%rs320, %rs322}; + mov.b32 %r12855, {%rs324, %rs326}; + mov.b32 %r12856, {%rs328, %rs330}; + mov.b32 %r12857, {%rs332, %rs334}; + mov.b32 %r12858, {%rs336, %rs338}; + mov.b32 %r12987, {%rs340, %rs342}; + mov.b32 %r12988, {%rs344, %rs346}; + mov.b32 %r12989, {%rs348, %rs350}; + mov.b32 %r12990, {%rs352, %rs354}; + mov.b32 %r13119, {%rs356, %rs358}; + mov.b32 %r13120, {%rs360, %rs362}; + mov.b32 %r13121, {%rs364, %rs366}; + mov.b32 %r13122, {%rs368, %rs370}; + wgmma.fence.sync.aligned; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r12723,%r12724,%r12725,%r12726}, %rd930, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13567, %r11508, 2048; + bfe.u32 %r13568, %r13567, 4, 14; + cvt.u64.u32 %rd1050, %r13568; + or.b64 %rd966, %rd1050, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r12855,%r12856,%r12857,%r12858}, %rd966, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13569, %r11508, 4096; + bfe.u32 %r13570, %r13569, 4, 14; + cvt.u64.u32 %rd1051, %r13570; + or.b64 %rd967, %rd1051, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r12987,%r12988,%r12989,%r12990}, %rd967, %p1133, 1, 1, 1; + // end inline asm + add.s32 %r13571, %r11508, 6144; + bfe.u32 %r13572, %r13571, 4, 14; + cvt.u64.u32 %rd1052, %r13572; + or.b64 %rd968, %rd1052, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14726,%r14727,%r14728,%r14729,%r14730,%r14731,%r14732,%r14733,%r14734,%r14735,%r14736,%r14737,%r14738,%r14739,%r14740,%r14741,%r14742,%r14743,%r14744,%r14745,%r14746,%r14747,%r14748,%r14749,%r14750,%r14751,%r14752,%r14753,%r14754,%r14755,%r14756,%r14757,%r14758,%r14759,%r14760,%r14761,%r14762,%r14763,%r14764,%r14765,%r14766,%r14767,%r14768,%r14769,%r14770,%r14771,%r14772,%r14773,%r14774,%r14775,%r14776,%r14777,%r14778,%r14779,%r14780,%r14781,%r14782,%r14783,%r14784,%r14785,%r14786,%r14787,%r14788,%r14789}, {%r13119,%r13120,%r13121,%r13122}, %rd968, %p1133, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r2055, %r15107, 1; + .loc 1 752 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:752:33 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shr.u32 %r13573, %r2055, 1; + .loc 1 753 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mad.wide.u32 %rd970, %r13573, 4, %rd524; + .loc 1 753 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + mov.u64 %rd969, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd969, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r13123, 0x0; + @%p1155 ld.global.L1::evict_last.L2::cache_hint.b32 { %r13123 }, [ %rd970 + 0 ], %rd969; + // end inline asm + .loc 1 754 109 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:109 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13574, %r13573, 1; + .loc 1 754 113 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:113 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1224, %r13574, %r7368; + .loc 1 754 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:55 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd973, %rd970, 4; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + and.pred %p1156, %p1155, %p1224; + .loc 1 754 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + // begin inline asm + mov.u64 %rd972, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd972, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r13124, 0x0; + @%p1156 ld.global.L1::evict_last.L2::cache_hint.b32 { %r13124 }, [ %rd973 + 0 ], %rd972; + // end inline asm + .loc 1 755 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:755:35 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + and.b32 %r13575, %r15107, 1; + .loc 1 756 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:34 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + sub.s32 %r13576, %r13124, %r13123; + .loc 1 756 48 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:48 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13577, %r13576, 7; + .loc 1 756 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13578, %r13577, -64; + .loc 1 757 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + xor.b32 %r13579, %r13575, 1; + .loc 1 757 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13580, %r13575, 6; + .loc 1 757 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mad.lo.s32 %r13581, %r13578, %r13579, %r13580; + .loc 1 608 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13582, %r13581, 12; + .loc 1 608 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:608:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.wide.s32 %rd1053, %r13582, 2; + add.s64 %rd1136, %rd1136, %rd1053; + add.s64 %rd1135, %rd1135, %rd1053; + add.s64 %rd1134, %rd1134, %rd1053; + add.s64 %rd1133, %rd1133, %rd1053; + .loc 1 609 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13583, %r13581, 7; + .loc 1 609 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:609:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.wide.s32 %rd1054, %r13583, 2; + add.s64 %rd1132, %rd1132, %rd1054; + add.s64 %rd1131, %rd1131, %rd1054; + add.s64 %rd1130, %rd1130, %rd1054; + add.s64 %rd1129, %rd1129, %rd1054; + .loc 1 610 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:610:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r2056, %r13581, %r14970; + add.s32 %r2057, %r13581, %r14969; + add.s32 %r2058, %r13581, %r14968; + add.s32 %r2059, %r13581, %r14967; + add.s32 %r2060, %r13581, %r14966; + add.s32 %r2061, %r13581, %r14965; + add.s32 %r2062, %r13581, %r14964; + add.s32 %r2063, %r13581, %r14963; + add.s32 %r2064, %r13581, %r14962; + add.s32 %r2065, %r13581, %r14961; + add.s32 %r2066, %r13581, %r14960; + add.s32 %r2067, %r13581, %r14959; + add.s32 %r2068, %r13581, %r14958; + add.s32 %r2069, %r13581, %r14957; + add.s32 %r2070, %r13581, %r14956; + add.s32 %r2071, %r13581, %r14955; + add.s32 %r14975, %r13581, %r14975; + add.s32 %r14976, %r13581, %r14976; + add.s32 %r14977, %r13581, %r14977; + add.s32 %r14978, %r13581, %r14978; + add.s32 %r14971, %r13581, %r14971; + add.s32 %r14972, %r13581, %r14972; + add.s32 %r14973, %r13581, %r14973; + add.s32 %r14974, %r13581, %r14974; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13584, %r14952, 1; + setp.gt.s32 %p1225, %r13584, 1; + selp.b32 %r14952, 0, %r13584, %p1225; + add.s32 %r13585, %r14954, 1; + setp.gt.s32 %p1226, %r13585, 2; + selp.b32 %r14954, 0, %r13585, %p1226; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1227, %r14975, %r2338; + setp.lt.s32 %p1228, %r14976, %r2338; + setp.lt.s32 %p1229, %r14977, %r2338; + setp.lt.s32 %p1230, %r14978, %r2338; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13586, %r14954, 14; + add.s32 %r13587, %r7390, %r13586; + bar.sync 0; + add.s32 %r13125, %r13587, %r761; + selp.b32 %r13588, 16, 0, %p1227; + selp.b32 %r13126, %r13588, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13125 + 0 ], [ %rd1136 + 0 ], 0x10, %r13126; + // end inline asm + add.s32 %r13127, %r13125, 2048; + selp.b32 %r13589, 16, 0, %p1228; + selp.b32 %r13128, %r13589, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13127 + 0 ], [ %rd1135 + 0 ], 0x10, %r13128; + // end inline asm + add.s32 %r13129, %r13125, 4096; + selp.b32 %r13590, 16, 0, %p1229; + selp.b32 %r13130, %r13590, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13129 + 0 ], [ %rd1134 + 0 ], 0x10, %r13130; + // end inline asm + add.s32 %r13131, %r13125, 6144; + selp.b32 %r13591, 16, 0, %p1230; + selp.b32 %r13132, %r13591, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13131 + 0 ], [ %rd1133 + 0 ], 0x10, %r13132; + // end inline asm + cp.async.commit_group; + .loc 1 656 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1231, %r2056, %r2338; + setp.lt.s32 %p1232, %r2057, %r2338; + setp.lt.s32 %p1233, %r2058, %r2338; + setp.lt.s32 %p1234, %r2059, %r2338; + setp.lt.s32 %p1235, %r2060, %r2338; + setp.lt.s32 %p1236, %r2061, %r2338; + setp.lt.s32 %p1237, %r2062, %r2338; + setp.lt.s32 %p1238, %r2063, %r2338; + setp.lt.s32 %p1239, %r2064, %r2338; + setp.lt.s32 %p1240, %r2065, %r2338; + setp.lt.s32 %p1241, %r2066, %r2338; + setp.lt.s32 %p1242, %r2067, %r2338; + setp.lt.s32 %p1243, %r2068, %r2338; + setp.lt.s32 %p1244, %r2069, %r2338; + setp.lt.s32 %p1245, %r2070, %r2338; + setp.lt.s32 %p1246, %r2071, %r2338; + .loc 1 656 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + mul.wide.s32 %rd1055, %r2056, 4; + add.s64 %rd979, %rd127, %rd1055; + mul.wide.s32 %rd1056, %r2057, 4; + add.s64 %rd980, %rd127, %rd1056; + mul.wide.s32 %rd1057, %r2058, 4; + add.s64 %rd981, %rd127, %rd1057; + mul.wide.s32 %rd1058, %r2059, 4; + add.s64 %rd982, %rd127, %rd1058; + mul.wide.s32 %rd1059, %r2060, 4; + add.s64 %rd983, %rd127, %rd1059; + mul.wide.s32 %rd1060, %r2061, 4; + add.s64 %rd984, %rd127, %rd1060; + mul.wide.s32 %rd1061, %r2062, 4; + add.s64 %rd985, %rd127, %rd1061; + mul.wide.s32 %rd1062, %r2063, 4; + add.s64 %rd986, %rd127, %rd1062; + mul.wide.s32 %rd1063, %r2064, 4; + add.s64 %rd987, %rd127, %rd1063; + mul.wide.s32 %rd1064, %r2065, 4; + add.s64 %rd988, %rd127, %rd1064; + mul.wide.s32 %rd1065, %r2066, 4; + add.s64 %rd989, %rd127, %rd1065; + mul.wide.s32 %rd1066, %r2067, 4; + add.s64 %rd990, %rd127, %rd1066; + mul.wide.s32 %rd1067, %r2068, 4; + add.s64 %rd991, %rd127, %rd1067; + mul.wide.s32 %rd1068, %r2069, 4; + add.s64 %rd992, %rd127, %rd1068; + mul.wide.s32 %rd1069, %r2070, 4; + add.s64 %rd993, %rd127, %rd1069; + mul.wide.s32 %rd1070, %r2071, 4; + add.s64 %rd994, %rd127, %rd1070; + .loc 1 656 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:656:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + shl.b32 %r13592, %r14952, 8; + add.s32 %r13593, %r13210, %r13592; + add.s32 %r13133, %r13593, %r787; + selp.b32 %r13594, 4, 0, %p1231; + selp.b32 %r13174, %r13594, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13133 + 0 ], [ %rd979 + 0 ], 0x4, %r13174; + // end inline asm + add.s32 %r13135, %r13133, 4; + selp.b32 %r13595, 4, 0, %p1232; + selp.b32 %r13176, %r13595, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13135 + 0 ], [ %rd980 + 0 ], 0x4, %r13176; + // end inline asm + add.s32 %r13137, %r13133, 32; + selp.b32 %r13596, 4, 0, %p1233; + selp.b32 %r13178, %r13596, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13137 + 0 ], [ %rd981 + 0 ], 0x4, %r13178; + // end inline asm + add.s32 %r13139, %r13133, 36; + selp.b32 %r13597, 4, 0, %p1234; + selp.b32 %r13180, %r13597, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13139 + 0 ], [ %rd982 + 0 ], 0x4, %r13180; + // end inline asm + add.s32 %r13141, %r13133, 64; + selp.b32 %r13598, 4, 0, %p1235; + selp.b32 %r13182, %r13598, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13141 + 0 ], [ %rd983 + 0 ], 0x4, %r13182; + // end inline asm + add.s32 %r13143, %r13133, 68; + selp.b32 %r13599, 4, 0, %p1236; + selp.b32 %r13184, %r13599, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13143 + 0 ], [ %rd984 + 0 ], 0x4, %r13184; + // end inline asm + add.s32 %r13145, %r13133, 96; + selp.b32 %r13600, 4, 0, %p1237; + selp.b32 %r13186, %r13600, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13145 + 0 ], [ %rd985 + 0 ], 0x4, %r13186; + // end inline asm + add.s32 %r13147, %r13133, 100; + selp.b32 %r13601, 4, 0, %p1238; + selp.b32 %r13188, %r13601, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13147 + 0 ], [ %rd986 + 0 ], 0x4, %r13188; + // end inline asm + add.s32 %r13149, %r13133, 128; + selp.b32 %r13602, 4, 0, %p1239; + selp.b32 %r13190, %r13602, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13149 + 0 ], [ %rd987 + 0 ], 0x4, %r13190; + // end inline asm + add.s32 %r13151, %r13133, 132; + selp.b32 %r13603, 4, 0, %p1240; + selp.b32 %r13192, %r13603, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13151 + 0 ], [ %rd988 + 0 ], 0x4, %r13192; + // end inline asm + add.s32 %r13153, %r13133, 160; + selp.b32 %r13604, 4, 0, %p1241; + selp.b32 %r13194, %r13604, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13153 + 0 ], [ %rd989 + 0 ], 0x4, %r13194; + // end inline asm + add.s32 %r13155, %r13133, 164; + selp.b32 %r13605, 4, 0, %p1242; + selp.b32 %r13196, %r13605, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13155 + 0 ], [ %rd990 + 0 ], 0x4, %r13196; + // end inline asm + add.s32 %r13157, %r13133, 192; + selp.b32 %r13606, 4, 0, %p1243; + selp.b32 %r13198, %r13606, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13157 + 0 ], [ %rd991 + 0 ], 0x4, %r13198; + // end inline asm + add.s32 %r13159, %r13133, 196; + selp.b32 %r13607, 4, 0, %p1244; + selp.b32 %r13200, %r13607, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13159 + 0 ], [ %rd992 + 0 ], 0x4, %r13200; + // end inline asm + add.s32 %r13161, %r13133, 224; + selp.b32 %r13608, 4, 0, %p1245; + selp.b32 %r13202, %r13608, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13161 + 0 ], [ %rd993 + 0 ], 0x4, %r13202; + // end inline asm + add.s32 %r13163, %r13133, 228; + selp.b32 %r13609, 4, 0, %p1246; + selp.b32 %r13204, %r13609, 0, %p1189; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13163 + 0 ], [ %rd994 + 0 ], 0x4, %r13204; + // end inline asm + cp.async.commit_group; + .loc 1 797 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.lt.s32 %p1247, %r14971, %r2338; + setp.lt.s32 %p1248, %r14972, %r2338; + setp.lt.s32 %p1249, %r14973, %r2338; + setp.lt.s32 %p1250, %r14974, %r2338; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13610, %r13446, %r13586; + add.s32 %r13165, %r13610, %r761; + selp.b32 %r13611, 16, 0, %p1247; + selp.b32 %r13166, %r13611, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13165 + 0 ], [ %rd1132 + 0 ], 0x10, %r13166; + // end inline asm + add.s32 %r13167, %r13165, 2048; + selp.b32 %r13612, 16, 0, %p1248; + selp.b32 %r13168, %r13612, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13167 + 0 ], [ %rd1131 + 0 ], 0x10, %r13168; + // end inline asm + add.s32 %r13169, %r13165, 4096; + selp.b32 %r13613, 16, 0, %p1249; + selp.b32 %r13170, %r13613, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13169 + 0 ], [ %rd1130 + 0 ], 0x10, %r13170; + // end inline asm + add.s32 %r13171, %r13165, 6144; + selp.b32 %r13614, 16, 0, %p1250; + selp.b32 %r13172, %r13614, 0, %p1189; + // begin inline asm + cp.async.cg.shared.global [ %r13171 + 0 ], [ %rd1129 + 0 ], 0x10, %r13172; + // end inline asm + cp.async.commit_group; + .loc 1 712 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s64 %rd999, %rd128, %rd1055; + add.s64 %rd1000, %rd128, %rd1056; + add.s64 %rd1001, %rd128, %rd1057; + add.s64 %rd1002, %rd128, %rd1058; + add.s64 %rd1003, %rd128, %rd1059; + add.s64 %rd1004, %rd128, %rd1060; + add.s64 %rd1005, %rd128, %rd1061; + add.s64 %rd1006, %rd128, %rd1062; + add.s64 %rd1007, %rd128, %rd1063; + add.s64 %rd1008, %rd128, %rd1064; + add.s64 %rd1009, %rd128, %rd1065; + add.s64 %rd1010, %rd128, %rd1066; + add.s64 %rd1011, %rd128, %rd1067; + add.s64 %rd1012, %rd128, %rd1068; + add.s64 %rd1013, %rd128, %rd1069; + add.s64 %rd1014, %rd128, %rd1070; + .loc 1 712 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:712:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + add.s32 %r13615, %r13454, %r13592; + add.s32 %r13173, %r13615, %r787; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13173 + 0 ], [ %rd999 + 0 ], 0x4, %r13174; + // end inline asm + add.s32 %r13175, %r13173, 4; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13175 + 0 ], [ %rd1000 + 0 ], 0x4, %r13176; + // end inline asm + add.s32 %r13177, %r13173, 32; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13177 + 0 ], [ %rd1001 + 0 ], 0x4, %r13178; + // end inline asm + add.s32 %r13179, %r13173, 36; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13179 + 0 ], [ %rd1002 + 0 ], 0x4, %r13180; + // end inline asm + add.s32 %r13181, %r13173, 64; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13181 + 0 ], [ %rd1003 + 0 ], 0x4, %r13182; + // end inline asm + add.s32 %r13183, %r13173, 68; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13183 + 0 ], [ %rd1004 + 0 ], 0x4, %r13184; + // end inline asm + add.s32 %r13185, %r13173, 96; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13185 + 0 ], [ %rd1005 + 0 ], 0x4, %r13186; + // end inline asm + add.s32 %r13187, %r13173, 100; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13187 + 0 ], [ %rd1006 + 0 ], 0x4, %r13188; + // end inline asm + add.s32 %r13189, %r13173, 128; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13189 + 0 ], [ %rd1007 + 0 ], 0x4, %r13190; + // end inline asm + add.s32 %r13191, %r13173, 132; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13191 + 0 ], [ %rd1008 + 0 ], 0x4, %r13192; + // end inline asm + add.s32 %r13193, %r13173, 160; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13193 + 0 ], [ %rd1009 + 0 ], 0x4, %r13194; + // end inline asm + add.s32 %r13195, %r13173, 164; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13195 + 0 ], [ %rd1010 + 0 ], 0x4, %r13196; + // end inline asm + add.s32 %r13197, %r13173, 192; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13197 + 0 ], [ %rd1011 + 0 ], 0x4, %r13198; + // end inline asm + add.s32 %r13199, %r13173, 196; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13199 + 0 ], [ %rd1012 + 0 ], 0x4, %r13200; + // end inline asm + add.s32 %r13201, %r13173, 224; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13201 + 0 ], [ %rd1013 + 0 ], 0x4, %r13202; + // end inline asm + add.s32 %r13203, %r13173, 228; + // begin inline asm + @%p1068 cp.async.ca.shared.global [ %r13203 + 0 ], [ %rd1014 + 0 ], 0x4, %r13204; + // end inline asm + cp.async.commit_group; + .loc 1 592 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:592:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:318:20 ] + setp.ne.b32 %p1251, %r985, %r2055; + mov.b32 %r14935, %r14955; + mov.b32 %r14936, %r14956; + mov.b32 %r14937, %r14957; + mov.b32 %r14938, %r14958; + mov.b32 %r14939, %r14959; + mov.b32 %r14940, %r14960; + mov.b32 %r14941, %r14961; + mov.b32 %r14942, %r14962; + mov.b32 %r14943, %r14963; + mov.b32 %r14944, %r14964; + mov.b32 %r14945, %r14965; + mov.b32 %r14946, %r14966; + mov.b32 %r14947, %r14967; + mov.b32 %r14948, %r14968; + mov.b32 %r14949, %r14969; + mov.b32 %r14950, %r14970; + mov.b32 %r14955, %r2071; + mov.b32 %r14956, %r2070; + mov.b32 %r14957, %r2069; + mov.b32 %r14958, %r2068; + mov.b32 %r14959, %r2067; + mov.b32 %r14960, %r2066; + mov.b32 %r14961, %r2065; + mov.b32 %r14962, %r2064; + mov.b32 %r14963, %r2063; + mov.b32 %r14964, %r2062; + mov.b32 %r14965, %r2061; + mov.b32 %r14966, %r2060; + mov.b32 %r14967, %r2059; + mov.b32 %r14968, %r2058; + mov.b32 %r14969, %r2057; + mov.b32 %r14970, %r2056; + mov.b32 %r15107, %r2055; + @%p1251 bra $L__BB0_14; + bra.uni $L__BB0_15; +$L__tmp17: +$L__BB0_1: + .loc 1 0 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:28 + ld.param.b64 %rd188, [triton_tem_fused_mul_1_param_13]; + ld.param.b64 %rd187, [triton_tem_fused_mul_1_param_12]; + ld.param.b64 %rd184, [triton_tem_fused_mul_1_param_9]; + ld.param.b64 %rd183, [triton_tem_fused_mul_1_param_8]; + ld.param.b64 %rd182, [triton_tem_fused_mul_1_param_6]; +$L__tmp18: + .loc 2 41 22 // standard.py:41:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:113:34 ] + add.s32 %r2520, %r2338, 127; + .loc 2 41 28 // standard.py:41:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:113:34 ] + shr.s32 %r2521, %r2520, 31; + shr.u32 %r2522, %r2521, 25; + add.s32 %r2523, %r2520, %r2522; + shr.s32 %r2524, %r2523, 7; +$L__tmp19: + .loc 1 140 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:140:24 + sub.s32 %r2525, %r4, %r5; + .loc 1 144 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:144:29 + div.s32 %r2527, %r2525, %r2524; + .loc 1 144 54 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:144:54 + shl.b32 %r2528, %r7, 2; + .loc 1 144 44 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:144:44 + add.s32 %r2529, %r2527, %r2528; + .loc 1 145 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:145:35 + mul.lo.s32 %r2530, %r2527, %r2524; + sub.s32 %r2531, %r2525, %r2530; + .loc 1 158 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:158:30 + shl.b32 %r2532, %r2529, 7; + .loc 1 158 40 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:158:40 + mad.lo.s32 %r2533, %r1, %r6, %r2532; + .loc 1 159 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:159:55 + mul.lo.s32 %r2534, %r2, %r6; + .loc 1 159 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:159:42 + mad.lo.s32 %r2535, %r2529, %r3, %r2534; + .loc 1 161 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:161:30 + shl.b32 %r2536, %r6, 5; + .loc 1 161 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:161:35 + add.s32 %r2537, %r2529, %r2536; + .loc 1 161 46 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:161:46 + mul.lo.s32 %r2538, %r2537, %r2338; + .loc 1 163 17 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:163:17 + mul.wide.s32 %rd234, %r2533, 2; + add.s64 %rd235, %rd178, %rd234; + .loc 1 164 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:164:19 + mad.wide.s32 %rd236, %r2535, 2, %rd181; + .loc 1 168 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:168:21 + mul.wide.s32 %rd237, %r2538, 4; + add.s64 %rd238, %rd179, %rd237; + .loc 1 169 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:169:25 + add.s64 %rd239, %rd180, %rd237; + .loc 1 174 36 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:174:36 + shl.b32 %r2539, %r2531, 7; + .loc 1 175 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:175:29 + or.b32 %r22, %r2539, %r12; + or.b32 %r23, %r2539, %r13; + or.b32 %r24, %r2539, %r14; + or.b32 %r25, %r2539, %r15; + or.b32 %r26, %r2539, %r16; + or.b32 %r27, %r2539, %r17; + or.b32 %r28, %r2539, %r18; + or.b32 %r29, %r2539, %r19; + or.b32 %r30, %r2539, %r20; + or.b32 %r31, %r2539, %r21; +$L__tmp20: + .loc 1 789 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + shl.b32 %r2540, %r22, 12; + shl.b32 %r2541, %r23, 12; + shl.b32 %r2542, %r24, 12; + shl.b32 %r2543, %r25, 12; + shl.b32 %r2544, %r26, 12; + shl.b32 %r2545, %r27, 12; + shl.b32 %r2546, %r28, 12; + shl.b32 %r2547, %r29, 12; + .loc 1 789 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + mad.wide.s32 %rd240, %r2540, 2, %rd235; + mad.wide.s32 %rd241, %r2541, 2, %rd235; + mad.wide.s32 %rd242, %r2542, 2, %rd235; + mad.wide.s32 %rd243, %r2543, 2, %rd235; + mad.wide.s32 %rd244, %r2544, 2, %rd235; + mad.wide.s32 %rd245, %r2545, 2, %rd235; + mad.wide.s32 %rd246, %r2546, 2, %rd235; + mad.wide.s32 %rd247, %r2547, 2, %rd235; + .loc 1 789 56 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:56 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + shl.b32 %r2548, %r9, 3; + and.b32 %r2549, %r2548, 120; + .loc 1 789 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + cvt.u64.u32 %rd13, %r2549; + mul.wide.u32 %rd248, %r2549, 2; + add.s64 %rd196, %rd240, %rd248; + add.s64 %rd197, %rd241, %rd248; + add.s64 %rd198, %rd242, %rd248; + add.s64 %rd199, %rd243, %rd248; + add.s64 %rd200, %rd244, %rd248; + add.s64 %rd201, %rd245, %rd248; + add.s64 %rd202, %rd246, %rd248; + add.s64 %rd203, %rd247, %rd248; + .loc 1 797 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + setp.lt.s32 %p12, %r22, %r2338; + setp.lt.s32 %p13, %r23, %r2338; + setp.lt.s32 %p14, %r24, %r2338; + setp.lt.s32 %p15, %r25, %r2338; + setp.lt.s32 %p16, %r26, %r2338; + setp.lt.s32 %p17, %r27, %r2338; + setp.lt.s32 %p18, %r28, %r2338; + setp.lt.s32 %p19, %r29, %r2338; + mov.b32 %r14256, 0; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:178:107 ] + // begin inline asm + mov.u32 %r2353, %r14256; + mov.u32 %r2354, %r14256; + mov.u32 %r2355, %r14256; + mov.u32 %r2356, %r14256; + @%p12 ld.global.v4.b32 { %r2353, %r2354, %r2355, %r2356 }, [ %rd196 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2361, %r14256; + mov.u32 %r2362, %r14256; + mov.u32 %r2363, %r14256; + mov.u32 %r2364, %r14256; + @%p13 ld.global.v4.b32 { %r2361, %r2362, %r2363, %r2364 }, [ %rd197 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2369, %r14256; + mov.u32 %r2370, %r14256; + mov.u32 %r2371, %r14256; + mov.u32 %r2372, %r14256; + @%p14 ld.global.v4.b32 { %r2369, %r2370, %r2371, %r2372 }, [ %rd198 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2377, %r14256; + mov.u32 %r2378, %r14256; + mov.u32 %r2379, %r14256; + mov.u32 %r2380, %r14256; + @%p15 ld.global.v4.b32 { %r2377, %r2378, %r2379, %r2380 }, [ %rd199 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2385, %r14256; + mov.u32 %r2386, %r14256; + mov.u32 %r2387, %r14256; + mov.u32 %r2388, %r14256; + @%p16 ld.global.v4.b32 { %r2385, %r2386, %r2387, %r2388 }, [ %rd200 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2393, %r14256; + mov.u32 %r2394, %r14256; + mov.u32 %r2395, %r14256; + mov.u32 %r2396, %r14256; + @%p17 ld.global.v4.b32 { %r2393, %r2394, %r2395, %r2396 }, [ %rd201 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2401, %r14256; + mov.u32 %r2402, %r14256; + mov.u32 %r2403, %r14256; + mov.u32 %r2404, %r14256; + @%p18 ld.global.v4.b32 { %r2401, %r2402, %r2403, %r2404 }, [ %rd202 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2409, %r14256; + mov.u32 %r2410, %r14256; + mov.u32 %r2411, %r14256; + mov.u32 %r2412, %r14256; + @%p19 ld.global.v4.b32 { %r2409, %r2410, %r2411, %r2412 }, [ %rd203 + 0 ]; + // end inline asm + shl.b32 %r2550, %r9, 4; + and.b32 %r2551, %r2550, 112; + shl.b32 %r2552, %r11, 3; + and.b32 %r2553, %r9, 112; + and.b32 %r2554, %r9, 8; + shl.b32 %r2555, %r2554, 11; + or.b32 %r2556, %r2551, %r2552; + xor.b32 %r2557, %r2556, %r2553; + or.b32 %r2558, %r2557, %r2555; + mov.b32 %r2559, global_smem; + add.s32 %r2560, %r2559, %r2558; + st.shared.v4.b32 [%r2560+98304], {%r2353, %r2354, %r2355, %r2356}; + st.shared.v4.b32 [%r2560+100352], {%r2361, %r2362, %r2363, %r2364}; + st.shared.v4.b32 [%r2560+102400], {%r2369, %r2370, %r2371, %r2372}; + st.shared.v4.b32 [%r2560+104448], {%r2377, %r2378, %r2379, %r2380}; + st.shared.v4.b32 [%r2560+106496], {%r2385, %r2386, %r2387, %r2388}; + st.shared.v4.b32 [%r2560+108544], {%r2393, %r2394, %r2395, %r2396}; + st.shared.v4.b32 [%r2560+110592], {%r2401, %r2402, %r2403, %r2404}; + st.shared.v4.b32 [%r2560+112640], {%r2409, %r2410, %r2411, %r2412}; +$L__tmp21: + .loc 1 789 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:179:111 ] + shl.b32 %r2561, %r22, 7; + shl.b32 %r2562, %r23, 7; + shl.b32 %r2563, %r24, 7; + shl.b32 %r2564, %r25, 7; + shl.b32 %r2565, %r26, 7; + shl.b32 %r2566, %r27, 7; + shl.b32 %r2567, %r28, 7; + shl.b32 %r2568, %r29, 7; + .loc 1 789 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:179:111 ] + mad.wide.s32 %rd249, %r2561, 2, %rd236; + mad.wide.s32 %rd250, %r2562, 2, %rd236; + mad.wide.s32 %rd251, %r2563, 2, %rd236; + mad.wide.s32 %rd252, %r2564, 2, %rd236; + mad.wide.s32 %rd253, %r2565, 2, %rd236; + mad.wide.s32 %rd254, %r2566, 2, %rd236; + mad.wide.s32 %rd255, %r2567, 2, %rd236; + mad.wide.s32 %rd256, %r2568, 2, %rd236; + .loc 1 789 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:789:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:179:111 ] + add.s64 %rd204, %rd249, %rd248; + add.s64 %rd205, %rd250, %rd248; + add.s64 %rd206, %rd251, %rd248; + add.s64 %rd207, %rd252, %rd248; + add.s64 %rd208, %rd253, %rd248; + add.s64 %rd209, %rd254, %rd248; + add.s64 %rd210, %rd255, %rd248; + add.s64 %rd211, %rd256, %rd248; + .loc 1 797 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:797:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:179:111 ] + // begin inline asm + mov.u32 %r2417, %r14256; + mov.u32 %r2418, %r14256; + mov.u32 %r2419, %r14256; + mov.u32 %r2420, %r14256; + @%p12 ld.global.v4.b32 { %r2417, %r2418, %r2419, %r2420 }, [ %rd204 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2425, %r14256; + mov.u32 %r2426, %r14256; + mov.u32 %r2427, %r14256; + mov.u32 %r2428, %r14256; + @%p13 ld.global.v4.b32 { %r2425, %r2426, %r2427, %r2428 }, [ %rd205 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2433, %r14256; + mov.u32 %r2434, %r14256; + mov.u32 %r2435, %r14256; + mov.u32 %r2436, %r14256; + @%p14 ld.global.v4.b32 { %r2433, %r2434, %r2435, %r2436 }, [ %rd206 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2441, %r14256; + mov.u32 %r2442, %r14256; + mov.u32 %r2443, %r14256; + mov.u32 %r2444, %r14256; + @%p15 ld.global.v4.b32 { %r2441, %r2442, %r2443, %r2444 }, [ %rd207 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2449, %r14256; + mov.u32 %r2450, %r14256; + mov.u32 %r2451, %r14256; + mov.u32 %r2452, %r14256; + @%p16 ld.global.v4.b32 { %r2449, %r2450, %r2451, %r2452 }, [ %rd208 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2457, %r14256; + mov.u32 %r2458, %r14256; + mov.u32 %r2459, %r14256; + mov.u32 %r2460, %r14256; + @%p17 ld.global.v4.b32 { %r2457, %r2458, %r2459, %r2460 }, [ %rd209 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2465, %r14256; + mov.u32 %r2466, %r14256; + mov.u32 %r2467, %r14256; + mov.u32 %r2468, %r14256; + @%p18 ld.global.v4.b32 { %r2465, %r2466, %r2467, %r2468 }, [ %rd210 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2473, %r14256; + mov.u32 %r2474, %r14256; + mov.u32 %r2475, %r14256; + mov.u32 %r2476, %r14256; + @%p19 ld.global.v4.b32 { %r2473, %r2474, %r2475, %r2476 }, [ %rd211 + 0 ]; + // end inline asm + st.shared.v4.b32 [%r2560+131072], {%r2417, %r2418, %r2419, %r2420}; + st.shared.v4.b32 [%r2560+133120], {%r2425, %r2426, %r2427, %r2428}; + st.shared.v4.b32 [%r2560+135168], {%r2433, %r2434, %r2435, %r2436}; + st.shared.v4.b32 [%r2560+137216], {%r2441, %r2442, %r2443, %r2444}; + st.shared.v4.b32 [%r2560+139264], {%r2449, %r2450, %r2451, %r2452}; + st.shared.v4.b32 [%r2560+141312], {%r2457, %r2458, %r2459, %r2460}; + st.shared.v4.b32 [%r2560+143360], {%r2465, %r2466, %r2467, %r2468}; + st.shared.v4.b32 [%r2560+145408], {%r2473, %r2474, %r2475, %r2476}; +$L__tmp22: + .loc 1 188 58 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:188:58 + setp.lt.s32 %p28, %r30, %r2338; + setp.lt.s32 %p29, %r31, %r2338; + .loc 1 188 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:188:34 + mul.wide.s32 %rd257, %r30, 4; + add.s64 %rd212, %rd239, %rd257; + cvt.s64.s32 %rd258, %r2539; + cvt.u64.u32 %rd259, %r20; + or.b64 %rd260, %rd258, %rd259; + shl.b64 %rd261, %rd260, 2; + add.s64 %rd262, %rd239, %rd261; + add.s64 %rd213, %rd262, 32; + .loc 1 188 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:188:25 + // begin inline asm + mov.u32 %r2481, 0x0; + @%p28 ld.global.b32 { %r2481 }, [ %rd212 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2482, 0x0; + @%p29 ld.global.b32 { %r2482 }, [ %rd213 + 0 ]; + // end inline asm + .loc 1 189 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:189:33 + add.s64 %rd214, %rd238, %rd257; + add.s64 %rd263, %rd238, %rd261; + add.s64 %rd215, %rd263, 32; + .loc 1 189 26 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:189:26 + // begin inline asm + mov.u32 %r2483, 0x0; + @%p28 ld.global.b32 { %r2483 }, [ %rd214 + 0 ]; + // end inline asm + // begin inline asm + mov.u32 %r2484, 0x0; + @%p29 ld.global.b32 { %r2484 }, [ %rd215 + 0 ]; + // end inline asm + .loc 1 190 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:190:30 + setp.eq.f32 %p32, %r2483, 0fFF800000; + setp.eq.f32 %p33, %r2484, 0fFF800000; + .loc 1 190 50 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:190:50 + selp.f32 %r34, 0f00000000, %r2483, %p32; + selp.f32 %r35, 0f00000000, %r2484, %p33; + .loc 1 195 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:195:30 + cvt.s64.s32 %rd14, %r2531; + mul.wide.s32 %rd264, %r2531, 4; + add.s64 %rd216, %rd184, %rd264; + .loc 1 196 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:196:27 + // begin inline asm + mov.u32 %r2485, 0x0; + ld.global.b32 { %r2485 }, [ %rd216 + 0 ]; + // end inline asm + .loc 1 196 41 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:196:41 + shl.b32 %r36, %r2485, 7; + .loc 1 197 53 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:197:53 + add.s64 %rd217, %rd183, %rd264; + .loc 1 197 39 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:197:39 + // begin inline asm + mov.u32 %r2486, 0x0; + ld.global.b32 { %r2486 }, [ %rd217 + 0 ]; + // end inline asm + .loc 1 199 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:199:42 + and.b32 %r38, %r9, 3; + shl.b32 %r41, %r38, 1; + or.b32 %r40, %r41, 1; + or.b32 %r44, %r41, 9; + or.b32 %r43, %r41, 8; + or.b32 %r48, %r41, 16; + or.b32 %r47, %r41, 17; + or.b32 %r46, %r41, 24; + or.b32 %r45, %r41, 25; + or.b32 %r56, %r41, 32; + or.b32 %r55, %r41, 33; + or.b32 %r54, %r41, 40; + or.b32 %r53, %r41, 41; + or.b32 %r52, %r41, 48; + or.b32 %r51, %r41, 49; + or.b32 %r50, %r41, 56; + or.b32 %r49, %r41, 57; + .loc 1 199 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:199:29 + or.b32 %r2569, %r36, %r12; + or.b32 %r2570, %r36, %r13; + or.b32 %r2571, %r36, %r14; + or.b32 %r2572, %r36, %r15; +$L__tmp23: + .loc 1 390 37 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:37 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r2573, %r2569, 10; + shl.b32 %r2574, %r2570, 10; + shl.b32 %r2575, %r2571, 10; + shl.b32 %r2576, %r2572, 10; + .loc 1 390 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.wide.s32 %rd265, %r2573, 2; + add.s64 %rd266, %rd1, %rd265; + mul.wide.s32 %rd267, %r2574, 2; + add.s64 %rd268, %rd1, %rd267; + mul.wide.s32 %rd269, %r2575, 2; + add.s64 %rd270, %rd1, %rd269; + mul.wide.s32 %rd271, %r2576, 2; + add.s64 %rd272, %rd1, %rd271; + .loc 1 390 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd218, %rd266, %rd248; + add.s64 %rd219, %rd268, %rd248; + add.s64 %rd220, %rd270, %rd248; + add.s64 %rd221, %rd272, %rd248; + .loc 1 391 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:391:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd273, %rd2, %rd265; + add.s64 %rd274, %rd2, %rd267; + add.s64 %rd275, %rd2, %rd269; + add.s64 %rd276, %rd2, %rd271; + .loc 1 391 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:391:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd222, %rd273, %rd248; + add.s64 %rd223, %rd274, %rd248; + add.s64 %rd224, %rd275, %rd248; + add.s64 %rd225, %rd276, %rd248; + .loc 1 395 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:395:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r2577, %r2486, 1; + .loc 2 41 22 // standard.py:41:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r2578, %r2339, 63; + .loc 2 41 28 // standard.py:41:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r2579, %r2578, 31; + shr.u32 %r2580, %r2579, 26; + add.s32 %r2581, %r2578, %r2580; + shr.s32 %r2582, %r2581, 6; + .loc 1 395 101 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:395:101 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + max.s32 %r57, %r2582, 1; + .loc 1 395 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:395:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + min.s32 %r58, %r2577, %r57; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p34, %r2577, 1; + setp.gt.s32 %p35, %r2577, 0; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p36, %r2569, %r2339; + setp.lt.s32 %p37, %r2570, %r2339; + setp.lt.s32 %p38, %r2571, %r2339; + setp.lt.s32 %p39, %r2572, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r2583, %r2554, 10; + or.b32 %r59, %r2557, %r2583; + add.s32 %r4884, %r2559, %r59; + selp.b32 %r2584, 16, 0, %p36; + selp.b32 %r2496, %r2584, 0, %p35; + // begin inline asm + cp.async.cg.shared.global [ %r4884 + 0 ], [ %rd218 + 0 ], 0x10, %r2496; + // end inline asm + add.s32 %r4886, %r4884, 2048; + selp.b32 %r2585, 16, 0, %p37; + selp.b32 %r2498, %r2585, 0, %p35; + // begin inline asm + cp.async.cg.shared.global [ %r4886 + 0 ], [ %rd219 + 0 ], 0x10, %r2498; + // end inline asm + add.s32 %r4888, %r4884, 4096; + selp.b32 %r2586, 16, 0, %p38; + selp.b32 %r2500, %r2586, 0, %p35; + // begin inline asm + cp.async.cg.shared.global [ %r4888 + 0 ], [ %rd220 + 0 ], 0x10, %r2500; + // end inline asm + add.s32 %r4890, %r4884, 6144; + selp.b32 %r2587, 16, 0, %p39; + selp.b32 %r2502, %r2587, 0, %p35; + // begin inline asm + cp.async.cg.shared.global [ %r4890 + 0 ], [ %rd221 + 0 ], 0x10, %r2502; + // end inline asm + cp.async.commit_group; + add.s32 %r2495, %r4884, 49152; + // begin inline asm + cp.async.cg.shared.global [ %r2495 + 0 ], [ %rd222 + 0 ], 0x10, %r2496; + // end inline asm + add.s32 %r2497, %r4884, 51200; + // begin inline asm + cp.async.cg.shared.global [ %r2497 + 0 ], [ %rd223 + 0 ], 0x10, %r2498; + // end inline asm + add.s32 %r2499, %r4884, 53248; + // begin inline asm + cp.async.cg.shared.global [ %r2499 + 0 ], [ %rd224 + 0 ], 0x10, %r2500; + // end inline asm + add.s32 %r2501, %r4884, 55296; + // begin inline asm + cp.async.cg.shared.global [ %r2501 + 0 ], [ %rd225 + 0 ], 0x10, %r2502; + // end inline asm + cp.async.commit_group; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.gt.s32 %p40, %r58, 1; + .loc 1 414 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd1111, %rd218, 131072; + add.s64 %rd1110, %rd219, 131072; + add.s64 %rd1109, %rd220, 131072; + add.s64 %rd1108, %rd221, 131072; + .loc 1 415 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:415:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd1107, %rd222, 131072; + add.s64 %rd1106, %rd223, 131072; + add.s64 %rd1105, %rd224, 131072; + add.s64 %rd1104, %rd225, 131072; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r14191, %r2569, 64; + or.b32 %r14190, %r2570, 64; + or.b32 %r14189, %r2571, 64; + or.b32 %r14188, %r2572, 64; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p41, %r14191, %r2339; + setp.lt.s32 %p42, %r14190, %r2339; + setp.lt.s32 %p43, %r14189, %r2339; + setp.lt.s32 %p44, %r14188, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + bar.sync 0; + add.s32 %r2503, %r4884, 16384; + selp.b32 %r2588, 16, 0, %p41; + selp.b32 %r2512, %r2588, 0, %p40; + // begin inline asm + cp.async.cg.shared.global [ %r2503 + 0 ], [ %rd1111 + 0 ], 0x10, %r2512; + // end inline asm + add.s32 %r2505, %r4884, 18432; + selp.b32 %r2589, 16, 0, %p42; + selp.b32 %r2514, %r2589, 0, %p40; + // begin inline asm + cp.async.cg.shared.global [ %r2505 + 0 ], [ %rd1110 + 0 ], 0x10, %r2514; + // end inline asm + add.s32 %r2507, %r4884, 20480; + selp.b32 %r2590, 16, 0, %p43; + selp.b32 %r2516, %r2590, 0, %p40; + // begin inline asm + cp.async.cg.shared.global [ %r2507 + 0 ], [ %rd1109 + 0 ], 0x10, %r2516; + // end inline asm + add.s32 %r2509, %r4884, 22528; + selp.b32 %r2591, 16, 0, %p44; + selp.b32 %r2518, %r2591, 0, %p40; + // begin inline asm + cp.async.cg.shared.global [ %r2509 + 0 ], [ %rd1108 + 0 ], 0x10, %r2518; + // end inline asm + cp.async.commit_group; + add.s32 %r2511, %r4884, 65536; + // begin inline asm + cp.async.cg.shared.global [ %r2511 + 0 ], [ %rd1107 + 0 ], 0x10, %r2512; + // end inline asm + add.s32 %r2513, %r4884, 67584; + // begin inline asm + cp.async.cg.shared.global [ %r2513 + 0 ], [ %rd1106 + 0 ], 0x10, %r2514; + // end inline asm + add.s32 %r2515, %r4884, 69632; + // begin inline asm + cp.async.cg.shared.global [ %r2515 + 0 ], [ %rd1105 + 0 ], 0x10, %r2516; + // end inline asm + add.s32 %r2517, %r4884, 71680; + // begin inline asm + cp.async.cg.shared.global [ %r2517 + 0 ], [ %rd1104 + 0 ], 0x10, %r2518; + // end inline asm + cp.async.commit_group; + .loc 1 459 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:459:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + // begin inline asm + fence.proxy.async.shared::cta; + // end inline asm + mov.b32 %r14192, 0f00000000; + mov.b32 %r14193, %r14192; + mov.b32 %r14194, %r14192; + mov.b32 %r14195, %r14192; + mov.b32 %r14196, %r14192; + mov.b32 %r14197, %r14192; + mov.b32 %r14198, %r14192; + mov.b32 %r14199, %r14192; + mov.b32 %r14200, %r14192; + mov.b32 %r14201, %r14192; + mov.b32 %r14202, %r14192; + mov.b32 %r14203, %r14192; + mov.b32 %r14204, %r14192; + mov.b32 %r14205, %r14192; + mov.b32 %r14206, %r14192; + mov.b32 %r14207, %r14192; + mov.b32 %r14208, %r14192; + mov.b32 %r14209, %r14192; + mov.b32 %r14210, %r14192; + mov.b32 %r14211, %r14192; + mov.b32 %r14212, %r14192; + mov.b32 %r14213, %r14192; + mov.b32 %r14214, %r14192; + mov.b32 %r14215, %r14192; + mov.b32 %r14216, %r14192; + mov.b32 %r14217, %r14192; + mov.b32 %r14218, %r14192; + mov.b32 %r14219, %r14192; + mov.b32 %r14220, %r14192; + mov.b32 %r14221, %r14192; + mov.b32 %r14222, %r14192; + mov.b32 %r14223, %r14192; + mov.b32 %r14224, %r14192; + mov.b32 %r14225, %r14192; + mov.b32 %r14226, %r14192; + mov.b32 %r14227, %r14192; + mov.b32 %r14228, %r14192; + mov.b32 %r14229, %r14192; + mov.b32 %r14230, %r14192; + mov.b32 %r14231, %r14192; + mov.b32 %r14232, %r14192; + mov.b32 %r14233, %r14192; + mov.b32 %r14234, %r14192; + mov.b32 %r14235, %r14192; + mov.b32 %r14236, %r14192; + mov.b32 %r14237, %r14192; + mov.b32 %r14238, %r14192; + mov.b32 %r14239, %r14192; + mov.b32 %r14240, %r14192; + mov.b32 %r14241, %r14192; + mov.b32 %r14242, %r14192; + mov.b32 %r14243, %r14192; + mov.b32 %r14244, %r14192; + mov.b32 %r14245, %r14192; + mov.b32 %r14246, %r14192; + mov.b32 %r14247, %r14192; + mov.b32 %r14248, %r14192; + mov.b32 %r14249, %r14192; + mov.b32 %r14250, %r14192; + mov.b32 %r14251, %r14192; + mov.b32 %r14252, %r14192; + mov.b32 %r14253, %r14192; + mov.b32 %r14254, %r14192; + mov.b32 %r14255, %r14192; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + @%p34 bra $L__BB0_4; +// %bb.2: // %.lr.ph + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r99, %r31, %r2338; + .loc 1 492 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r2597, %r99, 31; + shr.u32 %r2598, %r2597, 28; + add.s32 %r2599, %r99, %r2598; + shr.s32 %r2600, %r2599, 4; + .loc 1 481 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:481:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p1, %r99, 0; + .loc 1 492 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r2601, %r99, 15; + setp.ne.b32 %p45, %r2601, 0; + .loc 1 492 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p46, %p1, %p45; + selp.b32 %r2602, -1, 0, %p46; + add.s32 %r101, %r2600, %r2602; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r105, %r30, %r2338; + .loc 1 492 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r2603, %r105, 31; + shr.u32 %r2604, %r2603, 28; + add.s32 %r2605, %r105, %r2604; + shr.s32 %r2606, %r2605, 4; + .loc 1 481 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:481:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p5, %r105, 0; + .loc 1 492 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r2607, %r105, 15; + setp.ne.b32 %p47, %r2607, 0; + .loc 1 492 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:492:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p48, %p5, %p47; + selp.b32 %r2608, -1, 0, %p48; + add.s32 %r107, %r2606, %r2608; + .loc 1 485 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:485:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.gt.s32 %p3, %r99, -1; + setp.gt.s32 %p7, %r105, -1; +$L__tmp24: + .loc 1 199 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:199:29 + or.b32 %r14258, %r36, %r49; + or.b32 %r14257, %r36, %r50; + or.b32 %r14260, %r36, %r51; + or.b32 %r14259, %r36, %r52; + or.b32 %r14262, %r36, %r53; + or.b32 %r14261, %r36, %r54; + or.b32 %r14264, %r36, %r55; + or.b32 %r14263, %r36, %r56; + or.b32 %r14266, %r36, %r45; + or.b32 %r14265, %r36, %r46; + or.b32 %r14268, %r36, %r47; + or.b32 %r14267, %r36, %r48; + or.b32 %r14269, %r36, %r43; + or.b32 %r14270, %r36, %r44; + or.b32 %r14272, %r36, %r40; + or.b32 %r14271, %r36, %r41; + add.s32 %r96, %r58, -2; + add.s32 %r97, %r58, -1; +$L__tmp25: + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + max.s32 %r98, %r58, 1; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mov.b64 %rd24, {%r2482, %r2482}; + mov.b64 %rd25, {%r2481, %r2481}; + mov.b32 %r14192, 0f00000000; + mov.b32 %r14187, 1; + mov.b32 %r14186, -1; + mov.b32 %r14185, 64; + mov.b32 %r14193, %r14192; + mov.b32 %r14194, %r14192; + mov.b32 %r14195, %r14192; + mov.b32 %r14196, %r14192; + mov.b32 %r14197, %r14192; + mov.b32 %r14198, %r14192; + mov.b32 %r14199, %r14192; + mov.b32 %r14200, %r14192; + mov.b32 %r14201, %r14192; + mov.b32 %r14202, %r14192; + mov.b32 %r14203, %r14192; + mov.b32 %r14204, %r14192; + mov.b32 %r14205, %r14192; + mov.b32 %r14206, %r14192; + mov.b32 %r14207, %r14192; + mov.b32 %r14208, %r14192; + mov.b32 %r14209, %r14192; + mov.b32 %r14210, %r14192; + mov.b32 %r14211, %r14192; + mov.b32 %r14212, %r14192; + mov.b32 %r14213, %r14192; + mov.b32 %r14214, %r14192; + mov.b32 %r14215, %r14192; + mov.b32 %r14216, %r14192; + mov.b32 %r14217, %r14192; + mov.b32 %r14218, %r14192; + mov.b32 %r14219, %r14192; + mov.b32 %r14220, %r14192; + mov.b32 %r14221, %r14192; + mov.b32 %r14222, %r14192; + mov.b32 %r14223, %r14192; + mov.b32 %r14224, %r14192; + mov.b32 %r14225, %r14192; + mov.b32 %r14226, %r14192; + mov.b32 %r14227, %r14192; + mov.b32 %r14228, %r14192; + mov.b32 %r14229, %r14192; + mov.b32 %r14230, %r14192; + mov.b32 %r14231, %r14192; + mov.b32 %r14232, %r14192; + mov.b32 %r14233, %r14192; + mov.b32 %r14234, %r14192; + mov.b32 %r14235, %r14192; + mov.b32 %r14236, %r14192; + mov.b32 %r14237, %r14192; + mov.b32 %r14238, %r14192; + mov.b32 %r14239, %r14192; + mov.b32 %r14240, %r14192; + mov.b32 %r14241, %r14192; + mov.b32 %r14242, %r14192; + mov.b32 %r14243, %r14192; + mov.b32 %r14244, %r14192; + mov.b32 %r14245, %r14192; + mov.b32 %r14246, %r14192; + mov.b32 %r14247, %r14192; + mov.b32 %r14248, %r14192; + mov.b32 %r14249, %r14192; + mov.b32 %r14250, %r14192; + mov.b32 %r14251, %r14192; + mov.b32 %r14252, %r14192; + mov.b32 %r14253, %r14192; + mov.b32 %r14254, %r14192; + mov.b32 %r14255, %r14192; +$L__BB0_3: // %__nv_exp2f.exit1516 + // =>This Inner Loop Header: Depth=1 + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p69, %r14256, %r96; + setp.lt.s32 %p67, %r14256, %r97; + add.s32 %r4267, %r14186, 1; + setp.gt.s32 %p70, %r4267, 2; + selp.b32 %r14186, 0, %r4267, %p70; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p71, %r14271, %r2339; + setp.lt.s32 %p72, %r14272, %r2339; + setp.lt.s32 %p73, %r14269, %r2339; + setp.lt.s32 %p74, %r14270, %r2339; + setp.lt.s32 %p75, %r14267, %r2339; + setp.lt.s32 %p76, %r14268, %r2339; + setp.lt.s32 %p77, %r14265, %r2339; + setp.lt.s32 %p78, %r14266, %r2339; + setp.lt.s32 %p79, %r14263, %r2339; + setp.lt.s32 %p80, %r14264, %r2339; + setp.lt.s32 %p81, %r14261, %r2339; + setp.lt.s32 %p82, %r14262, %r2339; + setp.lt.s32 %p83, %r14259, %r2339; + setp.lt.s32 %p84, %r14260, %r2339; + setp.lt.s32 %p85, %r14257, %r2339; + setp.lt.s32 %p86, %r14258, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cp.async.wait_group 2; + bar.sync 0; + shl.b32 %r4268, %r14186, 14; + add.s32 %r3162, %r2559, %r4268; + .loc 1 459 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:459:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shfl.sync.idx.b32 %r4270, %r10, 0, 31, -1; + wgmma.fence.sync.aligned; + shl.b32 %r4271, %r4270, 11; + and.b32 %r4272, %r4271, 8192; + add.s32 %r3121, %r2559, 98304; + add.s32 %r4273, %r4272, %r3121; + bfe.u32 %r4274, %r4273, 4, 14; + cvt.u64.u32 %rd327, %r4274; + or.b64 %rd277, %rd327, 4611686293372403712; + bfe.u32 %r4275, %r3162, 4, 14; + cvt.u64.u32 %rd328, %r4275; + or.b64 %rd278, %rd328, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd277, %rd278, 0, 1, 1, 0, 0; + // end inline asm + or.b32 %r4276, %r4272, 32; + add.s32 %r4277, %r4276, %r3121; + bfe.u32 %r4278, %r4277, 4, 14; + cvt.u64.u32 %rd329, %r4278; + or.b64 %rd279, %rd329, 4611686293372403712; + add.s32 %r4279, %r3162, 32; + bfe.u32 %r4280, %r4279, 4, 14; + cvt.u64.u32 %rd330, %r4280; + or.b64 %rd280, %rd330, 4611686293338849280; + mov.pred %p49, -1; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd279, %rd280, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4281, %r4272, 64; + add.s32 %r4282, %r4281, %r3121; + bfe.u32 %r4283, %r4282, 4, 14; + cvt.u64.u32 %rd331, %r4283; + or.b64 %rd281, %rd331, 4611686293372403712; + add.s32 %r4284, %r3162, 64; + bfe.u32 %r4285, %r4284, 4, 14; + cvt.u64.u32 %rd332, %r4285; + or.b64 %rd282, %rd332, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd281, %rd282, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4286, %r4272, 96; + add.s32 %r4287, %r4286, %r3121; + bfe.u32 %r4288, %r4287, 4, 14; + cvt.u64.u32 %rd333, %r4288; + or.b64 %rd283, %rd333, 4611686293372403712; + add.s32 %r4289, %r3162, 96; + bfe.u32 %r4290, %r4289, 4, 14; + cvt.u64.u32 %rd334, %r4290; + or.b64 %rd284, %rd334, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd283, %rd284, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4291, %r4272, 16384; + add.s32 %r4292, %r4291, %r3121; + bfe.u32 %r4293, %r4292, 4, 14; + cvt.u64.u32 %rd335, %r4293; + or.b64 %rd285, %rd335, 4611686293372403712; + add.s32 %r4294, %r3162, 8192; + bfe.u32 %r4295, %r4294, 4, 14; + cvt.u64.u32 %rd336, %r4295; + or.b64 %rd286, %rd336, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd285, %rd286, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4296, %r4272, 16416; + add.s32 %r4297, %r4296, %r3121; + bfe.u32 %r4298, %r4297, 4, 14; + cvt.u64.u32 %rd337, %r4298; + or.b64 %rd287, %rd337, 4611686293372403712; + add.s32 %r4299, %r3162, 8224; + bfe.u32 %r4300, %r4299, 4, 14; + cvt.u64.u32 %rd338, %r4300; + or.b64 %rd288, %rd338, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd287, %rd288, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4301, %r4272, 16448; + add.s32 %r4302, %r4301, %r3121; + bfe.u32 %r4303, %r4302, 4, 14; + cvt.u64.u32 %rd339, %r4303; + or.b64 %rd289, %rd339, 4611686293372403712; + add.s32 %r4304, %r3162, 8256; + bfe.u32 %r4305, %r4304, 4, 14; + cvt.u64.u32 %rd340, %r4305; + or.b64 %rd290, %rd340, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd289, %rd290, %p49, 1, 1, 0, 0; + // end inline asm + or.b32 %r4306, %r4272, 16480; + add.s32 %r4307, %r4306, %r3121; + bfe.u32 %r4308, %r4307, 4, 14; + cvt.u64.u32 %rd341, %r4308; + or.b64 %rd291, %rd341, 4611686293372403712; + add.s32 %r4309, %r3162, 8288; + bfe.u32 %r4310, %r4309, 4, 14; + cvt.u64.u32 %rd342, %r4310; + or.b64 %rd292, %rd342, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736}, %rd291, %rd292, %p49, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r3678, 0; + mov.b32 %r3124, %r3162; + mov.b32 %r3122, %r3678; + mov.b32 %r3123, %r3678; + mov.b32 %r3125, %r3678; + mov.b32 %r3126, %r3678; + // begin inline asm + // wait for regs: %r2705,%r2706,%r2707,%r2708,%r2709,%r2710,%r2711,%r2712,%r2713,%r2714,%r2715,%r2716,%r2717,%r2718,%r2719,%r2720,%r2721,%r2722,%r2723,%r2724,%r2725,%r2726,%r2727,%r2728,%r2729,%r2730,%r2731,%r2732,%r2733,%r2734,%r2735,%r2736,%r3121,%r3122,%r3123,%r3124,%r3125,%r3126 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 461 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:461:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4311, %r2705, 0f3DB504F3; + mul.f32 %r4312, %r2706, 0f3DB504F3; + mul.f32 %r4313, %r2707, 0f3DB504F3; + mul.f32 %r4314, %r2708, 0f3DB504F3; + mul.f32 %r4315, %r2709, 0f3DB504F3; + mul.f32 %r4316, %r2710, 0f3DB504F3; + mul.f32 %r4317, %r2711, 0f3DB504F3; + mul.f32 %r4318, %r2712, 0f3DB504F3; + mul.f32 %r4319, %r2713, 0f3DB504F3; + mul.f32 %r4320, %r2714, 0f3DB504F3; + mul.f32 %r4321, %r2715, 0f3DB504F3; + mul.f32 %r4322, %r2716, 0f3DB504F3; + mul.f32 %r4323, %r2717, 0f3DB504F3; + mul.f32 %r4324, %r2718, 0f3DB504F3; + mul.f32 %r4325, %r2719, 0f3DB504F3; + mul.f32 %r4326, %r2720, 0f3DB504F3; + mul.f32 %r4327, %r2721, 0f3DB504F3; + mul.f32 %r4328, %r2722, 0f3DB504F3; + mul.f32 %r4329, %r2723, 0f3DB504F3; + mul.f32 %r4330, %r2724, 0f3DB504F3; + mul.f32 %r4331, %r2725, 0f3DB504F3; + mul.f32 %r4332, %r2726, 0f3DB504F3; + mul.f32 %r4333, %r2727, 0f3DB504F3; + mul.f32 %r4334, %r2728, 0f3DB504F3; + mul.f32 %r4335, %r2729, 0f3DB504F3; + mul.f32 %r4336, %r2730, 0f3DB504F3; + mul.f32 %r4337, %r2731, 0f3DB504F3; + mul.f32 %r4338, %r2732, 0f3DB504F3; + mul.f32 %r4339, %r2733, 0f3DB504F3; + mul.f32 %r4340, %r2734, 0f3DB504F3; + mul.f32 %r4341, %r2735, 0f3DB504F3; + mul.f32 %r4342, %r2736, 0f3DB504F3; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4343, %r14272, %r2339; + rem.s32 %r4344, %r14271, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p87, %r4344, %r99; + setp.le.s32 %p88, %r4343, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p89, %p1, %p88; + and.pred %p90, %p1, %p87; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p91, %r4343, 0; + setp.lt.s32 %p92, %r4344, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p93, %p3, %p92; + and.pred %p94, %p3, %p91; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4345, %r4344, %r99; + or.b32 %r4346, %r4343, %r99; + setp.gt.s32 %p95, %r4346, -1; + setp.gt.s32 %p96, %r4345, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4347, %r4344, 15; + and.b32 %r4348, %r4343, 15; + setp.ne.b32 %p97, %r4348, 0; + setp.ne.b32 %p98, %r4347, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4349, %r4343, 31; + shr.u32 %r4350, %r4349, 28; + add.s32 %r4351, %r4343, %r4350; + shr.s32 %r4352, %r4351, 4; + shr.s32 %r4353, %r4344, 31; + shr.u32 %r4354, %r4353, 28; + add.s32 %r4355, %r4344, %r4354; + shr.s32 %r4356, %r4355, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p99, %p92, %p98; + and.pred %p100, %p91, %p97; + selp.b32 %r4357, -1, 0, %p100; + selp.b32 %r4358, -1, 0, %p99; + add.s32 %r4359, %r4356, %r4358; + add.s32 %r4360, %r4352, %r4357; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p101, %r101, %r4360; + setp.eq.b32 %p102, %r101, %r4359; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p103, %p96, %p102; + and.pred %p104, %p95, %p101; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p105, %p94, %p104; + or.pred %p106, %p93, %p103; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p107, %p90, %p106; + or.pred %p108, %p89, %p105; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p109, %r4344, %r105; + setp.le.s32 %p110, %r4343, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p111, %p5, %p110; + and.pred %p112, %p5, %p109; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p113, %p7, %p92; + and.pred %p114, %p7, %p91; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4361, %r4344, %r105; + or.b32 %r4362, %r4343, %r105; + setp.gt.s32 %p115, %r4362, -1; + setp.gt.s32 %p116, %r4361, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p117, %r107, %r4360; + setp.eq.b32 %p118, %r107, %r4359; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p119, %p116, %p118; + and.pred %p120, %p115, %p117; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p121, %p114, %p120; + or.pred %p122, %p113, %p119; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p123, %p112, %p122; + or.pred %p124, %p111, %p121; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p125, %p124, %p72; + and.pred %p126, %p123, %p71; + and.pred %p127, %p108, %p72; + and.pred %p128, %p107, %p71; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4363, %r14270, %r2339; + rem.s32 %r4364, %r14269, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p129, %r4364, %r99; + setp.le.s32 %p130, %r4363, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p131, %p1, %p130; + and.pred %p132, %p1, %p129; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p133, %r4363, 0; + setp.lt.s32 %p134, %r4364, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p135, %p3, %p134; + and.pred %p136, %p3, %p133; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4365, %r4364, %r99; + or.b32 %r4366, %r4363, %r99; + setp.gt.s32 %p137, %r4366, -1; + setp.gt.s32 %p138, %r4365, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4367, %r4364, 15; + and.b32 %r4368, %r4363, 15; + setp.ne.b32 %p139, %r4368, 0; + setp.ne.b32 %p140, %r4367, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4369, %r4363, 31; + shr.u32 %r4370, %r4369, 28; + add.s32 %r4371, %r4363, %r4370; + shr.s32 %r4372, %r4371, 4; + shr.s32 %r4373, %r4364, 31; + shr.u32 %r4374, %r4373, 28; + add.s32 %r4375, %r4364, %r4374; + shr.s32 %r4376, %r4375, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p141, %p134, %p140; + and.pred %p142, %p133, %p139; + selp.b32 %r4377, -1, 0, %p142; + selp.b32 %r4378, -1, 0, %p141; + add.s32 %r4379, %r4376, %r4378; + add.s32 %r4380, %r4372, %r4377; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p143, %r101, %r4380; + setp.eq.b32 %p144, %r101, %r4379; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p145, %p138, %p144; + and.pred %p146, %p137, %p143; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p147, %p136, %p146; + or.pred %p148, %p135, %p145; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p149, %p132, %p148; + or.pred %p150, %p131, %p147; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p151, %r4364, %r105; + setp.le.s32 %p152, %r4363, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p153, %p5, %p152; + and.pred %p154, %p5, %p151; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p155, %p7, %p134; + and.pred %p156, %p7, %p133; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4381, %r4364, %r105; + or.b32 %r4382, %r4363, %r105; + setp.gt.s32 %p157, %r4382, -1; + setp.gt.s32 %p158, %r4381, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p159, %r107, %r4380; + setp.eq.b32 %p160, %r107, %r4379; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p161, %p158, %p160; + and.pred %p162, %p157, %p159; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p163, %p156, %p162; + or.pred %p164, %p155, %p161; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p165, %p154, %p164; + or.pred %p166, %p153, %p163; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p167, %p166, %p74; + and.pred %p168, %p165, %p73; + and.pred %p169, %p150, %p74; + and.pred %p170, %p149, %p73; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4383, %r14268, %r2339; + rem.s32 %r4384, %r14267, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p171, %r4384, %r99; + setp.le.s32 %p172, %r4383, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p173, %p1, %p172; + and.pred %p174, %p1, %p171; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p175, %r4383, 0; + setp.lt.s32 %p176, %r4384, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p177, %p3, %p176; + and.pred %p178, %p3, %p175; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4385, %r4384, %r99; + or.b32 %r4386, %r4383, %r99; + setp.gt.s32 %p179, %r4386, -1; + setp.gt.s32 %p180, %r4385, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4387, %r4384, 15; + and.b32 %r4388, %r4383, 15; + setp.ne.b32 %p181, %r4388, 0; + setp.ne.b32 %p182, %r4387, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4389, %r4383, 31; + shr.u32 %r4390, %r4389, 28; + add.s32 %r4391, %r4383, %r4390; + shr.s32 %r4392, %r4391, 4; + shr.s32 %r4393, %r4384, 31; + shr.u32 %r4394, %r4393, 28; + add.s32 %r4395, %r4384, %r4394; + shr.s32 %r4396, %r4395, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p183, %p176, %p182; + and.pred %p184, %p175, %p181; + selp.b32 %r4397, -1, 0, %p184; + selp.b32 %r4398, -1, 0, %p183; + add.s32 %r4399, %r4396, %r4398; + add.s32 %r4400, %r4392, %r4397; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p185, %r101, %r4400; + setp.eq.b32 %p186, %r101, %r4399; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p187, %p180, %p186; + and.pred %p188, %p179, %p185; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p189, %p178, %p188; + or.pred %p190, %p177, %p187; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p191, %p174, %p190; + or.pred %p192, %p173, %p189; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p193, %r4384, %r105; + setp.le.s32 %p194, %r4383, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p195, %p5, %p194; + and.pred %p196, %p5, %p193; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p197, %p7, %p176; + and.pred %p198, %p7, %p175; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4401, %r4384, %r105; + or.b32 %r4402, %r4383, %r105; + setp.gt.s32 %p199, %r4402, -1; + setp.gt.s32 %p200, %r4401, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p201, %r107, %r4400; + setp.eq.b32 %p202, %r107, %r4399; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p203, %p200, %p202; + and.pred %p204, %p199, %p201; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p205, %p198, %p204; + or.pred %p206, %p197, %p203; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p207, %p196, %p206; + or.pred %p208, %p195, %p205; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p209, %p208, %p76; + and.pred %p210, %p207, %p75; + and.pred %p211, %p192, %p76; + and.pred %p212, %p191, %p75; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4403, %r14266, %r2339; + rem.s32 %r4404, %r14265, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p213, %r4404, %r99; + setp.le.s32 %p214, %r4403, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p215, %p1, %p214; + and.pred %p216, %p1, %p213; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p217, %r4403, 0; + setp.lt.s32 %p218, %r4404, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p219, %p3, %p218; + and.pred %p220, %p3, %p217; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4405, %r4404, %r99; + or.b32 %r4406, %r4403, %r99; + setp.gt.s32 %p221, %r4406, -1; + setp.gt.s32 %p222, %r4405, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4407, %r4404, 15; + and.b32 %r4408, %r4403, 15; + setp.ne.b32 %p223, %r4408, 0; + setp.ne.b32 %p224, %r4407, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4409, %r4403, 31; + shr.u32 %r4410, %r4409, 28; + add.s32 %r4411, %r4403, %r4410; + shr.s32 %r4412, %r4411, 4; + shr.s32 %r4413, %r4404, 31; + shr.u32 %r4414, %r4413, 28; + add.s32 %r4415, %r4404, %r4414; + shr.s32 %r4416, %r4415, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p225, %p218, %p224; + and.pred %p226, %p217, %p223; + selp.b32 %r4417, -1, 0, %p226; + selp.b32 %r4418, -1, 0, %p225; + add.s32 %r4419, %r4416, %r4418; + add.s32 %r4420, %r4412, %r4417; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p227, %r101, %r4420; + setp.eq.b32 %p228, %r101, %r4419; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p229, %p222, %p228; + and.pred %p230, %p221, %p227; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p231, %p220, %p230; + or.pred %p232, %p219, %p229; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p233, %p216, %p232; + or.pred %p234, %p215, %p231; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p235, %r4404, %r105; + setp.le.s32 %p236, %r4403, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p237, %p5, %p236; + and.pred %p238, %p5, %p235; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p239, %p7, %p218; + and.pred %p240, %p7, %p217; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4421, %r4404, %r105; + or.b32 %r4422, %r4403, %r105; + setp.gt.s32 %p241, %r4422, -1; + setp.gt.s32 %p242, %r4421, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p243, %r107, %r4420; + setp.eq.b32 %p244, %r107, %r4419; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p245, %p242, %p244; + and.pred %p246, %p241, %p243; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p247, %p240, %p246; + or.pred %p248, %p239, %p245; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p249, %p238, %p248; + or.pred %p250, %p237, %p247; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p251, %p250, %p78; + and.pred %p252, %p249, %p77; + and.pred %p253, %p234, %p78; + and.pred %p254, %p233, %p77; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4423, %r14264, %r2339; + rem.s32 %r4424, %r14263, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p255, %r4424, %r99; + setp.le.s32 %p256, %r4423, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p257, %p1, %p256; + and.pred %p258, %p1, %p255; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p259, %r4423, 0; + setp.lt.s32 %p260, %r4424, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p261, %p3, %p260; + and.pred %p262, %p3, %p259; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4425, %r4424, %r99; + or.b32 %r4426, %r4423, %r99; + setp.gt.s32 %p263, %r4426, -1; + setp.gt.s32 %p264, %r4425, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4427, %r4424, 15; + and.b32 %r4428, %r4423, 15; + setp.ne.b32 %p265, %r4428, 0; + setp.ne.b32 %p266, %r4427, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4429, %r4423, 31; + shr.u32 %r4430, %r4429, 28; + add.s32 %r4431, %r4423, %r4430; + shr.s32 %r4432, %r4431, 4; + shr.s32 %r4433, %r4424, 31; + shr.u32 %r4434, %r4433, 28; + add.s32 %r4435, %r4424, %r4434; + shr.s32 %r4436, %r4435, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p267, %p260, %p266; + and.pred %p268, %p259, %p265; + selp.b32 %r4437, -1, 0, %p268; + selp.b32 %r4438, -1, 0, %p267; + add.s32 %r4439, %r4436, %r4438; + add.s32 %r4440, %r4432, %r4437; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p269, %r101, %r4440; + setp.eq.b32 %p270, %r101, %r4439; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p271, %p264, %p270; + and.pred %p272, %p263, %p269; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p273, %p262, %p272; + or.pred %p274, %p261, %p271; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p275, %p258, %p274; + or.pred %p276, %p257, %p273; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p277, %r4424, %r105; + setp.le.s32 %p278, %r4423, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p279, %p5, %p278; + and.pred %p280, %p5, %p277; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p281, %p7, %p260; + and.pred %p282, %p7, %p259; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4441, %r4424, %r105; + or.b32 %r4442, %r4423, %r105; + setp.gt.s32 %p283, %r4442, -1; + setp.gt.s32 %p284, %r4441, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p285, %r107, %r4440; + setp.eq.b32 %p286, %r107, %r4439; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p287, %p284, %p286; + and.pred %p288, %p283, %p285; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p289, %p282, %p288; + or.pred %p290, %p281, %p287; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p291, %p280, %p290; + or.pred %p292, %p279, %p289; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p293, %p292, %p80; + and.pred %p294, %p291, %p79; + and.pred %p295, %p276, %p80; + and.pred %p296, %p275, %p79; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4443, %r14262, %r2339; + rem.s32 %r4444, %r14261, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p297, %r4444, %r99; + setp.le.s32 %p298, %r4443, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p299, %p1, %p298; + and.pred %p300, %p1, %p297; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p301, %r4443, 0; + setp.lt.s32 %p302, %r4444, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p303, %p3, %p302; + and.pred %p304, %p3, %p301; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4445, %r4444, %r99; + or.b32 %r4446, %r4443, %r99; + setp.gt.s32 %p305, %r4446, -1; + setp.gt.s32 %p306, %r4445, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4447, %r4444, 15; + and.b32 %r4448, %r4443, 15; + setp.ne.b32 %p307, %r4448, 0; + setp.ne.b32 %p308, %r4447, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4449, %r4443, 31; + shr.u32 %r4450, %r4449, 28; + add.s32 %r4451, %r4443, %r4450; + shr.s32 %r4452, %r4451, 4; + shr.s32 %r4453, %r4444, 31; + shr.u32 %r4454, %r4453, 28; + add.s32 %r4455, %r4444, %r4454; + shr.s32 %r4456, %r4455, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p309, %p302, %p308; + and.pred %p310, %p301, %p307; + selp.b32 %r4457, -1, 0, %p310; + selp.b32 %r4458, -1, 0, %p309; + add.s32 %r4459, %r4456, %r4458; + add.s32 %r4460, %r4452, %r4457; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p311, %r101, %r4460; + setp.eq.b32 %p312, %r101, %r4459; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p313, %p306, %p312; + and.pred %p314, %p305, %p311; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p315, %p304, %p314; + or.pred %p316, %p303, %p313; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p317, %p300, %p316; + or.pred %p318, %p299, %p315; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p319, %r4444, %r105; + setp.le.s32 %p320, %r4443, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p321, %p5, %p320; + and.pred %p322, %p5, %p319; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p323, %p7, %p302; + and.pred %p324, %p7, %p301; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4461, %r4444, %r105; + or.b32 %r4462, %r4443, %r105; + setp.gt.s32 %p325, %r4462, -1; + setp.gt.s32 %p326, %r4461, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p327, %r107, %r4460; + setp.eq.b32 %p328, %r107, %r4459; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p329, %p326, %p328; + and.pred %p330, %p325, %p327; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p331, %p324, %p330; + or.pred %p332, %p323, %p329; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p333, %p322, %p332; + or.pred %p334, %p321, %p331; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p335, %p334, %p82; + and.pred %p336, %p333, %p81; + and.pred %p337, %p318, %p82; + and.pred %p338, %p317, %p81; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4463, %r14260, %r2339; + rem.s32 %r4464, %r14259, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p339, %r4464, %r99; + setp.le.s32 %p340, %r4463, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p341, %p1, %p340; + and.pred %p342, %p1, %p339; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p343, %r4463, 0; + setp.lt.s32 %p344, %r4464, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p345, %p3, %p344; + and.pred %p346, %p3, %p343; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4465, %r4464, %r99; + or.b32 %r4466, %r4463, %r99; + setp.gt.s32 %p347, %r4466, -1; + setp.gt.s32 %p348, %r4465, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4467, %r4464, 15; + and.b32 %r4468, %r4463, 15; + setp.ne.b32 %p349, %r4468, 0; + setp.ne.b32 %p350, %r4467, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4469, %r4463, 31; + shr.u32 %r4470, %r4469, 28; + add.s32 %r4471, %r4463, %r4470; + shr.s32 %r4472, %r4471, 4; + shr.s32 %r4473, %r4464, 31; + shr.u32 %r4474, %r4473, 28; + add.s32 %r4475, %r4464, %r4474; + shr.s32 %r4476, %r4475, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p351, %p344, %p350; + and.pred %p352, %p343, %p349; + selp.b32 %r4477, -1, 0, %p352; + selp.b32 %r4478, -1, 0, %p351; + add.s32 %r4479, %r4476, %r4478; + add.s32 %r4480, %r4472, %r4477; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p353, %r101, %r4480; + setp.eq.b32 %p354, %r101, %r4479; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p355, %p348, %p354; + and.pred %p356, %p347, %p353; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p357, %p346, %p356; + or.pred %p358, %p345, %p355; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p359, %p342, %p358; + or.pred %p360, %p341, %p357; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p361, %r4464, %r105; + setp.le.s32 %p362, %r4463, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p363, %p5, %p362; + and.pred %p364, %p5, %p361; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p365, %p7, %p344; + and.pred %p366, %p7, %p343; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4481, %r4464, %r105; + or.b32 %r4482, %r4463, %r105; + setp.gt.s32 %p367, %r4482, -1; + setp.gt.s32 %p368, %r4481, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p369, %r107, %r4480; + setp.eq.b32 %p370, %r107, %r4479; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p371, %p368, %p370; + and.pred %p372, %p367, %p369; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p373, %p366, %p372; + or.pred %p374, %p365, %p371; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p375, %p364, %p374; + or.pred %p376, %p363, %p373; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p377, %p376, %p84; + and.pred %p378, %p375, %p83; + and.pred %p379, %p360, %p84; + and.pred %p380, %p359, %p83; + .loc 1 762 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:762:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + rem.s32 %r4483, %r14258, %r2339; + rem.s32 %r4484, %r14257, %r2339; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p381, %r4484, %r99; + setp.le.s32 %p382, %r4483, %r99; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p383, %p1, %p382; + and.pred %p384, %p1, %p381; + .loc 1 486 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:486:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p385, %r4483, 0; + setp.lt.s32 %p386, %r4484, 0; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p387, %p3, %p386; + and.pred %p388, %p3, %p385; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4485, %r4484, %r99; + or.b32 %r4486, %r4483, %r99; + setp.gt.s32 %p389, %r4486, -1; + setp.gt.s32 %p390, %r4485, -1; + .loc 1 494 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4487, %r4484, 15; + and.b32 %r4488, %r4483, 15; + setp.ne.b32 %p391, %r4488, 0; + setp.ne.b32 %p392, %r4487, 0; + .loc 1 494 91 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:91 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.s32 %r4489, %r4483, 31; + shr.u32 %r4490, %r4489, 28; + add.s32 %r4491, %r4483, %r4490; + shr.s32 %r4492, %r4491, 4; + shr.s32 %r4493, %r4484, 31; + shr.u32 %r4494, %r4493, 28; + add.s32 %r4495, %r4484, %r4494; + shr.s32 %r4496, %r4495, 4; + .loc 1 494 119 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:494:119 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p393, %p386, %p392; + and.pred %p394, %p385, %p391; + selp.b32 %r4497, -1, 0, %p394; + selp.b32 %r4498, -1, 0, %p393; + add.s32 %r4499, %r4496, %r4498; + add.s32 %r4500, %r4492, %r4497; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p395, %r101, %r4500; + setp.eq.b32 %p396, %r101, %r4499; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p397, %p390, %p396; + and.pred %p398, %p389, %p395; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p399, %p388, %p398; + or.pred %p400, %p387, %p397; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p401, %p384, %p400; + or.pred %p402, %p383, %p399; + .loc 1 483 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:483:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.le.s32 %p403, %r4484, %r105; + setp.le.s32 %p404, %r4483, %r105; + .loc 1 484 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:484:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p405, %p5, %p404; + and.pred %p406, %p5, %p403; + .loc 1 487 22 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:487:22 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p407, %p7, %p386; + and.pred %p408, %p7, %p385; + .loc 1 489 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:489:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.b32 %r4501, %r4484, %r105; + or.b32 %r4502, %r4483, %r105; + setp.gt.s32 %p409, %r4502, -1; + setp.gt.s32 %p410, %r4501, -1; + .loc 1 495 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:495:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.eq.b32 %p411, %r107, %r4500; + setp.eq.b32 %p412, %r107, %r4499; + .loc 1 496 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:496:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p413, %p410, %p412; + and.pred %p414, %p409, %p411; + .loc 1 497 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:497:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p415, %p408, %p414; + or.pred %p416, %p407, %p413; + .loc 1 498 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:498:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + or.pred %p417, %p406, %p416; + or.pred %p418, %p405, %p415; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p419, %p418, %p86; + and.pred %p420, %p417, %p85; + and.pred %p421, %p402, %p86; + and.pred %p422, %p401, %p85; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4503, %r4311, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4504, %r4503, 0fFF800000, %p126; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4505, %r4312, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4506, %r4505, 0fFF800000, %p125; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4507, %r4313, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4508, %r4507, 0fFF800000, %p128; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4509, %r4314, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4510, %r4509, 0fFF800000, %p127; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4511, %r4315, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4512, %r4511, 0fFF800000, %p168; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4513, %r4316, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4514, %r4513, 0fFF800000, %p167; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4515, %r4317, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4516, %r4515, 0fFF800000, %p170; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4517, %r4318, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4518, %r4517, 0fFF800000, %p169; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4519, %r4319, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4520, %r4519, 0fFF800000, %p210; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4521, %r4320, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4522, %r4521, 0fFF800000, %p209; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4523, %r4321, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4524, %r4523, 0fFF800000, %p212; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4525, %r4322, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4526, %r4525, 0fFF800000, %p211; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4527, %r4323, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4528, %r4527, 0fFF800000, %p252; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4529, %r4324, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4530, %r4529, 0fFF800000, %p251; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4531, %r4325, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4532, %r4531, 0fFF800000, %p254; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4533, %r4326, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4534, %r4533, 0fFF800000, %p253; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4535, %r4327, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4536, %r4535, 0fFF800000, %p294; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4537, %r4328, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4538, %r4537, 0fFF800000, %p293; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4539, %r4329, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4540, %r4539, 0fFF800000, %p296; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4541, %r4330, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4542, %r4541, 0fFF800000, %p295; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4543, %r4331, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4544, %r4543, 0fFF800000, %p336; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4545, %r4332, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4546, %r4545, 0fFF800000, %p335; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4547, %r4333, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4548, %r4547, 0fFF800000, %p338; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4549, %r4334, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4550, %r4549, 0fFF800000, %p337; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4551, %r4335, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4552, %r4551, 0fFF800000, %p378; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4553, %r4336, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4554, %r4553, 0fFF800000, %p377; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4555, %r4337, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4556, %r4555, 0fFF800000, %p380; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4557, %r4338, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4558, %r4557, 0fFF800000, %p379; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4559, %r4339, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4560, %r4559, 0fFF800000, %p420; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4561, %r4340, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4562, %r4561, 0fFF800000, %p419; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4563, %r4341, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4564, %r4563, 0fFF800000, %p422; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4565, %r4342, 0f3FB8AA3B; + .loc 1 503 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:503:69 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.f32 %r4566, %r4565, 0fFF800000, %p421; + .loc 1 507 39 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:507:39 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4567, %r4504, %r34; + sub.f32 %r4568, %r4506, %r34; + sub.f32 %r4569, %r4508, %r35; + sub.f32 %r4570, %r4510, %r35; + sub.f32 %r4571, %r4512, %r34; + sub.f32 %r4572, %r4514, %r34; + sub.f32 %r4573, %r4516, %r35; + sub.f32 %r4574, %r4518, %r35; + sub.f32 %r4575, %r4520, %r34; + sub.f32 %r4576, %r4522, %r34; + sub.f32 %r4577, %r4524, %r35; + sub.f32 %r4578, %r4526, %r35; + sub.f32 %r4579, %r4528, %r34; + sub.f32 %r4580, %r4530, %r34; + sub.f32 %r4581, %r4532, %r35; + sub.f32 %r4582, %r4534, %r35; + sub.f32 %r4583, %r4536, %r34; + sub.f32 %r4584, %r4538, %r34; + sub.f32 %r4585, %r4540, %r35; + sub.f32 %r4586, %r4542, %r35; + sub.f32 %r4587, %r4544, %r34; + sub.f32 %r4588, %r4546, %r34; + sub.f32 %r4589, %r4548, %r35; + sub.f32 %r4590, %r4550, %r35; + sub.f32 %r4591, %r4552, %r34; + sub.f32 %r4592, %r4554, %r34; + sub.f32 %r4593, %r4556, %r35; + sub.f32 %r4594, %r4558, %r35; + sub.f32 %r4595, %r4560, %r34; + sub.f32 %r4596, %r4562, %r34; + sub.f32 %r4597, %r4564, %r35; + sub.f32 %r4598, %r4566, %r35; + .loc 1 507 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:507:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + ex2.approx.ftz.f32 %r4599, %r4567; + ex2.approx.ftz.f32 %r4600, %r4568; + ex2.approx.ftz.f32 %r4601, %r4569; + ex2.approx.ftz.f32 %r4602, %r4570; + ex2.approx.ftz.f32 %r4603, %r4571; + ex2.approx.ftz.f32 %r4604, %r4572; + ex2.approx.ftz.f32 %r4605, %r4573; + ex2.approx.ftz.f32 %r4606, %r4574; + ex2.approx.ftz.f32 %r4607, %r4575; + ex2.approx.ftz.f32 %r4608, %r4576; + ex2.approx.ftz.f32 %r4609, %r4577; + ex2.approx.ftz.f32 %r4610, %r4578; + ex2.approx.ftz.f32 %r4611, %r4579; + ex2.approx.ftz.f32 %r4612, %r4580; + ex2.approx.ftz.f32 %r4613, %r4581; + ex2.approx.ftz.f32 %r4614, %r4582; + ex2.approx.ftz.f32 %r4615, %r4583; + ex2.approx.ftz.f32 %r4616, %r4584; + ex2.approx.ftz.f32 %r4617, %r4585; + ex2.approx.ftz.f32 %r4618, %r4586; + ex2.approx.ftz.f32 %r4619, %r4587; + ex2.approx.ftz.f32 %r4620, %r4588; + ex2.approx.ftz.f32 %r4621, %r4589; + ex2.approx.ftz.f32 %r4622, %r4590; + ex2.approx.ftz.f32 %r4623, %r4591; + ex2.approx.ftz.f32 %r4624, %r4592; + ex2.approx.ftz.f32 %r4625, %r4593; + ex2.approx.ftz.f32 %r4626, %r4594; + ex2.approx.ftz.f32 %r4627, %r4595; + ex2.approx.ftz.f32 %r4628, %r4596; + ex2.approx.ftz.f32 %r4629, %r4597; + ex2.approx.ftz.f32 %r4630, %r4598; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r4631, %r2559, 49152; + add.s32 %r3680, %r4631, %r4268; + .loc 1 512 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:512:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + wgmma.fence.sync.aligned; + add.s32 %r3677, %r2559, 131072; + add.s32 %r4632, %r4272, %r3677; + bfe.u32 %r4633, %r4632, 4, 14; + cvt.u64.u32 %rd343, %r4633; + or.b64 %rd293, %rd343, 4611686293372403712; + bfe.u32 %r4634, %r3680, 4, 14; + cvt.u64.u32 %rd344, %r4634; + or.b64 %rd294, %rd344, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd293, %rd294, 0, 1, 1, 0, 0; + // end inline asm + add.s32 %r4635, %r4276, %r3677; + bfe.u32 %r4636, %r4635, 4, 14; + cvt.u64.u32 %rd345, %r4636; + or.b64 %rd295, %rd345, 4611686293372403712; + add.s32 %r4637, %r3680, 32; + bfe.u32 %r4638, %r4637, 4, 14; + cvt.u64.u32 %rd346, %r4638; + or.b64 %rd296, %rd346, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd295, %rd296, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4639, %r4281, %r3677; + bfe.u32 %r4640, %r4639, 4, 14; + cvt.u64.u32 %rd347, %r4640; + or.b64 %rd297, %rd347, 4611686293372403712; + add.s32 %r4641, %r3680, 64; + bfe.u32 %r4642, %r4641, 4, 14; + cvt.u64.u32 %rd348, %r4642; + or.b64 %rd298, %rd348, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd297, %rd298, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4643, %r4286, %r3677; + bfe.u32 %r4644, %r4643, 4, 14; + cvt.u64.u32 %rd349, %r4644; + or.b64 %rd299, %rd349, 4611686293372403712; + add.s32 %r4645, %r3680, 96; + bfe.u32 %r4646, %r4645, 4, 14; + cvt.u64.u32 %rd350, %r4646; + or.b64 %rd300, %rd350, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd299, %rd300, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4647, %r4291, %r3677; + bfe.u32 %r4648, %r4647, 4, 14; + cvt.u64.u32 %rd351, %r4648; + or.b64 %rd301, %rd351, 4611686293372403712; + add.s32 %r4649, %r3680, 8192; + bfe.u32 %r4650, %r4649, 4, 14; + cvt.u64.u32 %rd352, %r4650; + or.b64 %rd302, %rd352, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd301, %rd302, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4651, %r4296, %r3677; + bfe.u32 %r4652, %r4651, 4, 14; + cvt.u64.u32 %rd353, %r4652; + or.b64 %rd303, %rd353, 4611686293372403712; + add.s32 %r4653, %r3680, 8224; + bfe.u32 %r4654, %r4653, 4, 14; + cvt.u64.u32 %rd354, %r4654; + or.b64 %rd304, %rd354, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd303, %rd304, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4655, %r4301, %r3677; + bfe.u32 %r4656, %r4655, 4, 14; + cvt.u64.u32 %rd355, %r4656; + or.b64 %rd305, %rd355, 4611686293372403712; + add.s32 %r4657, %r3680, 8256; + bfe.u32 %r4658, %r4657, 4, 14; + cvt.u64.u32 %rd356, %r4658; + or.b64 %rd306, %rd356, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd305, %rd306, %p49, 1, 1, 0, 0; + // end inline asm + add.s32 %r4659, %r4306, %r3677; + bfe.u32 %r4660, %r4659, 4, 14; + cvt.u64.u32 %rd357, %r4660; + or.b64 %rd307, %rd357, 4611686293372403712; + add.s32 %r4661, %r3680, 8288; + bfe.u32 %r4662, %r4661, 4, 14; + cvt.u64.u32 %rd358, %r4662; + or.b64 %rd308, %rd358, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292}, %rd307, %rd308, %p49, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r3679, %r3678; + mov.b32 %r3681, %r3678; + mov.b32 %r3682, %r3678; + // begin inline asm + // wait for regs: %r3261,%r3262,%r3263,%r3264,%r3265,%r3266,%r3267,%r3268,%r3269,%r3270,%r3271,%r3272,%r3273,%r3274,%r3275,%r3276,%r3277,%r3278,%r3279,%r3280,%r3281,%r3282,%r3283,%r3284,%r3285,%r3286,%r3287,%r3288,%r3289,%r3290,%r3291,%r3292,%r3677,%r3678,%r3679,%r3680,%r3681,%r3682 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mov.b64 {%r4663, %r4664}, %rd25; + sub.f32 %r4665, %r3261, %r4663; + sub.f32 %r4666, %r3262, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4667, %r4600, %r4666; + mul.f32 %r4668, %r4599, %r4665; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs1, %r4668; + cvt.rn.bf16.f32 %rs2, %r4667; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs3, %rs2, 0x0000, %p125; + selp.b16 %rs4, %rs1, 0x0000, %p126; + mov.b32 %r3849, {%rs4, %rs3}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mov.b64 {%r4669, %r4670}, %rd24; + sub.f32 %r4671, %r3263, %r4669; + sub.f32 %r4672, %r3264, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4673, %r4602, %r4672; + mul.f32 %r4674, %r4601, %r4671; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs5, %r4674; + cvt.rn.bf16.f32 %rs6, %r4673; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs7, %rs6, 0x0000, %p127; + selp.b16 %rs8, %rs5, 0x0000, %p128; + mov.b32 %r3850, {%rs8, %rs7}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4675, %r3265, %r4663; + sub.f32 %r4676, %r3266, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4677, %r4604, %r4676; + mul.f32 %r4678, %r4603, %r4675; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs9, %r4678; + cvt.rn.bf16.f32 %rs10, %r4677; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs11, %rs10, 0x0000, %p167; + selp.b16 %rs12, %rs9, 0x0000, %p168; + mov.b32 %r3851, {%rs12, %rs11}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4679, %r3267, %r4669; + sub.f32 %r4680, %r3268, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4681, %r4606, %r4680; + mul.f32 %r4682, %r4605, %r4679; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs13, %r4682; + cvt.rn.bf16.f32 %rs14, %r4681; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs15, %rs14, 0x0000, %p169; + selp.b16 %rs16, %rs13, 0x0000, %p170; + mov.b32 %r3852, {%rs16, %rs15}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4683, %r3269, %r4663; + sub.f32 %r4684, %r3270, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4685, %r4608, %r4684; + mul.f32 %r4686, %r4607, %r4683; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs17, %r4686; + cvt.rn.bf16.f32 %rs18, %r4685; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs19, %rs18, 0x0000, %p209; + selp.b16 %rs20, %rs17, 0x0000, %p210; + mov.b32 %r3981, {%rs20, %rs19}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4687, %r3271, %r4669; + sub.f32 %r4688, %r3272, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4689, %r4610, %r4688; + mul.f32 %r4690, %r4609, %r4687; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs21, %r4690; + cvt.rn.bf16.f32 %rs22, %r4689; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs23, %rs22, 0x0000, %p211; + selp.b16 %rs24, %rs21, 0x0000, %p212; + mov.b32 %r3982, {%rs24, %rs23}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4691, %r3273, %r4663; + sub.f32 %r4692, %r3274, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4693, %r4612, %r4692; + mul.f32 %r4694, %r4611, %r4691; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs25, %r4694; + cvt.rn.bf16.f32 %rs26, %r4693; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs27, %rs26, 0x0000, %p251; + selp.b16 %rs28, %rs25, 0x0000, %p252; + mov.b32 %r3983, {%rs28, %rs27}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4695, %r3275, %r4669; + sub.f32 %r4696, %r3276, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4697, %r4614, %r4696; + mul.f32 %r4698, %r4613, %r4695; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs29, %r4698; + cvt.rn.bf16.f32 %rs30, %r4697; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs31, %rs30, 0x0000, %p253; + selp.b16 %rs32, %rs29, 0x0000, %p254; + mov.b32 %r3984, {%rs32, %rs31}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4699, %r3277, %r4663; + sub.f32 %r4700, %r3278, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4701, %r4616, %r4700; + mul.f32 %r4702, %r4615, %r4699; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs33, %r4702; + cvt.rn.bf16.f32 %rs34, %r4701; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs35, %rs34, 0x0000, %p293; + selp.b16 %rs36, %rs33, 0x0000, %p294; + mov.b32 %r4113, {%rs36, %rs35}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4703, %r3279, %r4669; + sub.f32 %r4704, %r3280, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4705, %r4618, %r4704; + mul.f32 %r4706, %r4617, %r4703; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs37, %r4706; + cvt.rn.bf16.f32 %rs38, %r4705; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs39, %rs38, 0x0000, %p295; + selp.b16 %rs40, %rs37, 0x0000, %p296; + mov.b32 %r4114, {%rs40, %rs39}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4707, %r3281, %r4663; + sub.f32 %r4708, %r3282, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4709, %r4620, %r4708; + mul.f32 %r4710, %r4619, %r4707; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs41, %r4710; + cvt.rn.bf16.f32 %rs42, %r4709; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs43, %rs42, 0x0000, %p335; + selp.b16 %rs44, %rs41, 0x0000, %p336; + mov.b32 %r4115, {%rs44, %rs43}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4711, %r3283, %r4669; + sub.f32 %r4712, %r3284, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4713, %r4622, %r4712; + mul.f32 %r4714, %r4621, %r4711; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs45, %r4714; + cvt.rn.bf16.f32 %rs46, %r4713; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs47, %rs46, 0x0000, %p337; + selp.b16 %rs48, %rs45, 0x0000, %p338; + mov.b32 %r4116, {%rs48, %rs47}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4715, %r3285, %r4663; + sub.f32 %r4716, %r3286, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4717, %r4624, %r4716; + mul.f32 %r4718, %r4623, %r4715; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs49, %r4718; + cvt.rn.bf16.f32 %rs50, %r4717; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs51, %rs50, 0x0000, %p377; + selp.b16 %rs52, %rs49, 0x0000, %p378; + mov.b32 %r4245, {%rs52, %rs51}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4719, %r3287, %r4669; + sub.f32 %r4720, %r3288, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4721, %r4626, %r4720; + mul.f32 %r4722, %r4625, %r4719; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs53, %r4722; + cvt.rn.bf16.f32 %rs54, %r4721; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs55, %rs54, 0x0000, %p379; + selp.b16 %rs56, %rs53, 0x0000, %p380; + mov.b32 %r4246, {%rs56, %rs55}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4723, %r3289, %r4663; + sub.f32 %r4724, %r3290, %r4664; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4725, %r4628, %r4724; + mul.f32 %r4726, %r4627, %r4723; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs57, %r4726; + cvt.rn.bf16.f32 %rs58, %r4725; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs59, %rs58, 0x0000, %p419; + selp.b16 %rs60, %rs57, 0x0000, %p420; + mov.b32 %r4247, {%rs60, %rs59}; + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.f32 %r4727, %r3291, %r4669; + sub.f32 %r4728, %r3292, %r4670; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.f32 %r4729, %r4630, %r4728; + mul.f32 %r4730, %r4629, %r4727; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + cvt.rn.bf16.f32 %rs61, %r4730; + cvt.rn.bf16.f32 %rs62, %r4729; + .loc 1 531 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:531:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + selp.b16 %rs63, %rs62, 0x0000, %p421; + selp.b16 %rs64, %rs61, 0x0000, %p422; + mov.b32 %r4248, {%rs64, %rs63}; + .loc 1 535 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:535:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + wgmma.fence.sync.aligned; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r3849,%r3850,%r3851,%r3852}, %rd278, %p49, 1, 1, 1; + // end inline asm + add.s32 %r4731, %r3162, 2048; + bfe.u32 %r4732, %r4731, 4, 14; + cvt.u64.u32 %rd359, %r4732; + or.b64 %rd310, %rd359, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r3981,%r3982,%r3983,%r3984}, %rd310, %p49, 1, 1, 1; + // end inline asm + add.s32 %r4733, %r3162, 4096; + bfe.u32 %r4734, %r4733, 4, 14; + cvt.u64.u32 %rd360, %r4734; + or.b64 %rd311, %rd360, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r4113,%r4114,%r4115,%r4116}, %rd311, %p49, 1, 1, 1; + // end inline asm + add.s32 %r4735, %r3162, 6144; + bfe.u32 %r4736, %r4735, 4, 14; + cvt.u64.u32 %rd361, %r4736; + or.b64 %rd312, %rd361, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r4245,%r4246,%r4247,%r4248}, %rd312, %p49, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r14271, %r14185, %r14271; + add.s32 %r14272, %r14185, %r14272; + add.s32 %r14269, %r14185, %r14269; + add.s32 %r14270, %r14185, %r14270; + add.s32 %r14267, %r14185, %r14267; + add.s32 %r14268, %r14185, %r14268; + add.s32 %r14265, %r14185, %r14265; + add.s32 %r14266, %r14185, %r14266; + add.s32 %r14263, %r14185, %r14263; + add.s32 %r14264, %r14185, %r14264; + add.s32 %r14261, %r14185, %r14261; + add.s32 %r14262, %r14185, %r14262; + add.s32 %r14259, %r14185, %r14259; + add.s32 %r14260, %r14185, %r14260; + add.s32 %r14257, %r14185, %r14257; + add.s32 %r14258, %r14185, %r14258; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r278, %r14256, 1; + .loc 1 752 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:752:33 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shr.u32 %r4737, %r278, 1; + .loc 1 753 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mad.wide.u32 %rd314, %r4737, 4, %rd216; + .loc 1 753 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + // begin inline asm + mov.u64 %rd313, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd313, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r4249, 0x0; + @%p67 ld.global.L1::evict_last.L2::cache_hint.b32 { %r4249 }, [ %rd314 + 0 ], %rd313; + // end inline asm + .loc 1 754 109 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:109 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r4738, %r4737, 1; + .loc 1 754 113 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:113 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p423, %r4738, %r2486; + .loc 1 754 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:55 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd317, %rd314, 4; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.pred %p68, %p67, %p423; + .loc 1 754 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + // begin inline asm + mov.u64 %rd316, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd316, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r4250, 0x0; + @%p68 ld.global.L1::evict_last.L2::cache_hint.b32 { %r4250 }, [ %rd317 + 0 ], %rd316; + // end inline asm + .loc 1 755 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:755:35 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + and.b32 %r4739, %r14256, 1; + .loc 1 756 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:34 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + sub.s32 %r4740, %r4250, %r4249; + .loc 1 756 48 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:48 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r4741, %r4740, 7; + .loc 1 756 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r4742, %r4741, -64; + .loc 1 757 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + xor.b32 %r4743, %r4739, 1; + .loc 1 757 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r4744, %r4739, 6; + .loc 1 757 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mad.lo.s32 %r14185, %r4742, %r4743, %r4744; + .loc 1 414 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r4745, %r14185, 10; + .loc 1 414 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + mul.wide.s32 %rd362, %r4745, 2; + add.s64 %rd1111, %rd1111, %rd362; + add.s64 %rd1110, %rd1110, %rd362; + add.s64 %rd1109, %rd1109, %rd362; + add.s64 %rd1108, %rd1108, %rd362; + .loc 1 415 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:415:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s64 %rd1107, %rd1107, %rd362; + add.s64 %rd1106, %rd1106, %rd362; + add.s64 %rd1105, %rd1105, %rd362; + add.s64 %rd1104, %rd1104, %rd362; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r14191, %r14185, %r14191; + add.s32 %r14190, %r14185, %r14190; + add.s32 %r14189, %r14185, %r14189; + add.s32 %r14188, %r14185, %r14188; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + add.s32 %r4746, %r14187, 1; + setp.gt.s32 %p424, %r4746, 2; + selp.b32 %r14187, 0, %r4746, %p424; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.lt.s32 %p425, %r14191, %r2339; + setp.lt.s32 %p426, %r14190, %r2339; + setp.lt.s32 %p427, %r14189, %r2339; + setp.lt.s32 %p428, %r14188, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + shl.b32 %r4747, %r14187, 14; + add.s32 %r4748, %r2559, %r4747; + bar.sync 0; + add.s32 %r4251, %r4748, %r59; + selp.b32 %r4749, 16, 0, %p425; + selp.b32 %r4260, %r4749, 0, %p69; + // begin inline asm + cp.async.cg.shared.global [ %r4251 + 0 ], [ %rd1111 + 0 ], 0x10, %r4260; + // end inline asm + add.s32 %r4253, %r4251, 2048; + selp.b32 %r4750, 16, 0, %p426; + selp.b32 %r4262, %r4750, 0, %p69; + // begin inline asm + cp.async.cg.shared.global [ %r4253 + 0 ], [ %rd1110 + 0 ], 0x10, %r4262; + // end inline asm + add.s32 %r4255, %r4251, 4096; + selp.b32 %r4751, 16, 0, %p427; + selp.b32 %r4264, %r4751, 0, %p69; + // begin inline asm + cp.async.cg.shared.global [ %r4255 + 0 ], [ %rd1109 + 0 ], 0x10, %r4264; + // end inline asm + add.s32 %r4257, %r4251, 6144; + selp.b32 %r4752, 16, 0, %p428; + selp.b32 %r4266, %r4752, 0, %p69; + // begin inline asm + cp.async.cg.shared.global [ %r4257 + 0 ], [ %rd1108 + 0 ], 0x10, %r4266; + // end inline asm + cp.async.commit_group; + add.s32 %r4753, %r4631, %r4747; + add.s32 %r4259, %r4753, %r59; + // begin inline asm + cp.async.cg.shared.global [ %r4259 + 0 ], [ %rd1107 + 0 ], 0x10, %r4260; + // end inline asm + add.s32 %r4261, %r4259, 2048; + // begin inline asm + cp.async.cg.shared.global [ %r4261 + 0 ], [ %rd1106 + 0 ], 0x10, %r4262; + // end inline asm + add.s32 %r4263, %r4259, 4096; + // begin inline asm + cp.async.cg.shared.global [ %r4263 + 0 ], [ %rd1105 + 0 ], 0x10, %r4264; + // end inline asm + add.s32 %r4265, %r4259, 6144; + // begin inline asm + cp.async.cg.shared.global [ %r4265 + 0 ], [ %rd1104 + 0 ], 0x10, %r4266; + // end inline asm + cp.async.commit_group; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + setp.ne.b32 %p429, %r98, %r278; + mov.b32 %r14256, %r278; + @%p429 bra $L__BB0_3; +$L__tmp26: +$L__BB0_4: // %._crit_edge + .loc 1 0 0 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0 + add.s64 %rd4, %rd182, %rd234; +$L__tmp27: + cvt.s64.s32 %rd5, %r2540; + cvt.s64.s32 %rd6, %r2541; + cvt.s64.s32 %rd7, %r2542; + cvt.s64.s32 %rd8, %r2543; + cvt.s64.s32 %rd9, %r2544; + cvt.s64.s32 %rd10, %r2545; + cvt.s64.s32 %rd11, %r2546; + cvt.s64.s32 %rd12, %r2547; +$L__tmp28: + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:207:12 ] + // begin inline asm + // wait for regs: %r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255 + wgmma.wait_group.sync.aligned 0; + // end inline asm + cp.async.wait_group 0; + bar.sync 0; +$L__tmp29: + .loc 1 214 39 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:214:39 + shl.b64 %rd381, %rd14, 2; + add.s64 %rd363, %rd188, %rd381; + .loc 1 215 31 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:215:31 + // begin inline asm + mov.u32 %r4882, 0x0; + ld.global.b32 { %r4882 }, [ %rd363 + 0 ]; + // end inline asm + .loc 1 215 45 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:215:45 + shl.b32 %r349, %r4882, 7; + .loc 1 216 62 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:216:62 + add.s64 %rd364, %rd187, %rd381; + .loc 1 216 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:216:43 + // begin inline asm + mov.u32 %r4883, 0x0; + ld.global.b32 { %r4883 }, [ %rd364 + 0 ]; + // end inline asm + .loc 1 218 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:218:33 + or.b32 %r4916, %r349, %r12; + or.b32 %r4917, %r349, %r13; + or.b32 %r4918, %r349, %r14; + or.b32 %r4919, %r349, %r15; +$L__tmp30: + .loc 1 390 37 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:37 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r4920, %r4916, 10; + shl.b32 %r4921, %r4917, 10; + shl.b32 %r4922, %r4918, 10; + shl.b32 %r4923, %r4919, 10; + .loc 1 390 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.wide.s32 %rd382, %r4920, 2; + add.s64 %rd383, %rd1, %rd382; + mul.wide.s32 %rd384, %r4921, 2; + add.s64 %rd385, %rd1, %rd384; + mul.wide.s32 %rd386, %r4922, 2; + add.s64 %rd387, %rd1, %rd386; + mul.wide.s32 %rd388, %r4923, 2; + add.s64 %rd389, %rd1, %rd388; + .loc 1 390 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:390:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b64 %rd390, %rd13, 1; + add.s64 %rd365, %rd383, %rd390; + add.s64 %rd366, %rd385, %rd390; + add.s64 %rd367, %rd387, %rd390; + add.s64 %rd368, %rd389, %rd390; + .loc 1 391 18 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:391:18 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd391, %rd2, %rd382; + add.s64 %rd392, %rd2, %rd384; + add.s64 %rd393, %rd2, %rd386; + add.s64 %rd394, %rd2, %rd388; + .loc 1 391 49 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:391:49 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd369, %rd391, %rd390; + add.s64 %rd370, %rd392, %rd390; + add.s64 %rd371, %rd393, %rd390; + add.s64 %rd372, %rd394, %rd390; + .loc 1 395 43 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:395:43 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r4924, %r4883, 1; + .loc 1 395 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:395:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + min.s32 %r351, %r4924, %r57; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p430, %r4924, 1; + setp.gt.s32 %p431, %r4924, 0; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p432, %r4916, %r2339; + setp.lt.s32 %p433, %r4917, %r2339; + setp.lt.s32 %p434, %r4918, %r2339; + setp.lt.s32 %p435, %r4919, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b32 %r4925, 16, 0, %p432; + selp.b32 %r4893, %r4925, 0, %p431; + // begin inline asm + cp.async.cg.shared.global [ %r4884 + 0 ], [ %rd365 + 0 ], 0x10, %r4893; + // end inline asm + selp.b32 %r4926, 16, 0, %p433; + selp.b32 %r4895, %r4926, 0, %p431; + // begin inline asm + cp.async.cg.shared.global [ %r4886 + 0 ], [ %rd366 + 0 ], 0x10, %r4895; + // end inline asm + selp.b32 %r4927, 16, 0, %p434; + selp.b32 %r4897, %r4927, 0, %p431; + // begin inline asm + cp.async.cg.shared.global [ %r4888 + 0 ], [ %rd367 + 0 ], 0x10, %r4897; + // end inline asm + selp.b32 %r4928, 16, 0, %p435; + selp.b32 %r4899, %r4928, 0, %p431; + // begin inline asm + cp.async.cg.shared.global [ %r4890 + 0 ], [ %rd368 + 0 ], 0x10, %r4899; + // end inline asm + cp.async.commit_group; + // begin inline asm + cp.async.cg.shared.global [ %r2495 + 0 ], [ %rd369 + 0 ], 0x10, %r4893; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2497 + 0 ], [ %rd370 + 0 ], 0x10, %r4895; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2499 + 0 ], [ %rd371 + 0 ], 0x10, %r4897; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2501 + 0 ], [ %rd372 + 0 ], 0x10, %r4899; + // end inline asm + cp.async.commit_group; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.gt.s32 %p436, %r351, 1; + .loc 1 414 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd1119, %rd365, 131072; + add.s64 %rd1118, %rd366, 131072; + add.s64 %rd1117, %rd367, 131072; + add.s64 %rd1116, %rd368, 131072; + .loc 1 415 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:415:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd1115, %rd369, 131072; + add.s64 %rd1114, %rd370, 131072; + add.s64 %rd1113, %rd371, 131072; + add.s64 %rd1112, %rd372, 131072; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + or.b32 %r14343, %r4916, 64; + or.b32 %r14342, %r4917, 64; + or.b32 %r14341, %r4918, 64; + or.b32 %r14340, %r4919, 64; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p437, %r14343, %r2339; + setp.lt.s32 %p438, %r14342, %r2339; + setp.lt.s32 %p439, %r14341, %r2339; + setp.lt.s32 %p440, %r14340, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + bar.sync 0; + selp.b32 %r4929, 16, 0, %p437; + selp.b32 %r4909, %r4929, 0, %p436; + // begin inline asm + cp.async.cg.shared.global [ %r2503 + 0 ], [ %rd1119 + 0 ], 0x10, %r4909; + // end inline asm + selp.b32 %r4930, 16, 0, %p438; + selp.b32 %r4911, %r4930, 0, %p436; + // begin inline asm + cp.async.cg.shared.global [ %r2505 + 0 ], [ %rd1118 + 0 ], 0x10, %r4911; + // end inline asm + selp.b32 %r4931, 16, 0, %p439; + selp.b32 %r4913, %r4931, 0, %p436; + // begin inline asm + cp.async.cg.shared.global [ %r2507 + 0 ], [ %rd1117 + 0 ], 0x10, %r4913; + // end inline asm + selp.b32 %r4932, 16, 0, %p440; + selp.b32 %r4915, %r4932, 0, %p436; + // begin inline asm + cp.async.cg.shared.global [ %r2509 + 0 ], [ %rd1116 + 0 ], 0x10, %r4915; + // end inline asm + cp.async.commit_group; + // begin inline asm + cp.async.cg.shared.global [ %r2511 + 0 ], [ %rd1115 + 0 ], 0x10, %r4909; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2513 + 0 ], [ %rd1114 + 0 ], 0x10, %r4911; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2515 + 0 ], [ %rd1113 + 0 ], 0x10, %r4913; + // end inline asm + // begin inline asm + cp.async.cg.shared.global [ %r2517 + 0 ], [ %rd1112 + 0 ], 0x10, %r4915; + // end inline asm + cp.async.commit_group; + .loc 1 459 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:459:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + // begin inline asm + fence.proxy.async.shared::cta; + // end inline asm + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + @%p430 bra $L__BB0_7; +$L__tmp31: +// %bb.5: // %.lr.ph1672 + .loc 1 218 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:218:33 + or.b32 %r14424, %r349, %r41; + or.b32 %r14423, %r349, %r40; + or.b32 %r14422, %r349, %r43; + or.b32 %r14417, %r349, %r45; + or.b32 %r14409, %r349, %r49; + or.b32 %r14421, %r349, %r44; + or.b32 %r14418, %r349, %r46; + or.b32 %r14410, %r349, %r50; + or.b32 %r14419, %r349, %r47; + or.b32 %r14411, %r349, %r51; + or.b32 %r14420, %r349, %r48; + or.b32 %r14412, %r349, %r52; + or.b32 %r14413, %r349, %r53; + or.b32 %r14414, %r349, %r54; + or.b32 %r14415, %r349, %r55; + or.b32 %r14416, %r349, %r56; + add.s32 %r436, %r351, -2; + add.s32 %r437, %r351, -1; +$L__tmp32: + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + max.s32 %r438, %r351, 1; + mov.b32 %r5488, 0; + mov.b32 %r14339, 1; + mov.b32 %r14338, -1; + mov.b32 %r14337, 64; + mov.b32 %r14408, %r5488; +$L__BB0_6: // %__nv_exp2f.exit1420 + // =>This Inner Loop Header: Depth=1 + setp.lt.s32 %p461, %r14408, %r436; + setp.lt.s32 %p459, %r14408, %r437; + add.s32 %r6595, %r14338, 1; + setp.gt.s32 %p462, %r6595, 2; + selp.b32 %r14338, 0, %r6595, %p462; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p463, %r14424, %r2339; + setp.lt.s32 %p464, %r14423, %r2339; + setp.lt.s32 %p465, %r14422, %r2339; + setp.lt.s32 %p466, %r14421, %r2339; + setp.lt.s32 %p467, %r14420, %r2339; + setp.lt.s32 %p468, %r14419, %r2339; + setp.lt.s32 %p469, %r14418, %r2339; + setp.lt.s32 %p470, %r14417, %r2339; + setp.lt.s32 %p471, %r14416, %r2339; + setp.lt.s32 %p472, %r14415, %r2339; + setp.lt.s32 %p473, %r14414, %r2339; + setp.lt.s32 %p474, %r14413, %r2339; + setp.lt.s32 %p475, %r14412, %r2339; + setp.lt.s32 %p476, %r14411, %r2339; + setp.lt.s32 %p477, %r14410, %r2339; + setp.lt.s32 %p478, %r14409, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cp.async.wait_group 2; + bar.sync 0; + shl.b32 %r6596, %r14338, 14; + add.s32 %r5490, %r2559, %r6596; + .loc 1 459 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:459:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shfl.sync.idx.b32 %r6598, %r10, 0, 31, -1; + wgmma.fence.sync.aligned; + shl.b32 %r6599, %r6598, 11; + and.b32 %r6600, %r6599, 8192; + add.s32 %r5449, %r2559, 98304; + add.s32 %r6601, %r6600, %r5449; + bfe.u32 %r6602, %r6601, 4, 14; + cvt.u64.u32 %rd445, %r6602; + or.b64 %rd395, %rd445, 4611686293372403712; + bfe.u32 %r6603, %r5490, 4, 14; + cvt.u64.u32 %rd446, %r6603; + or.b64 %rd396, %rd446, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd395, %rd396, 0, 1, 1, 0, 0; + // end inline asm + or.b32 %r6604, %r6600, 32; + add.s32 %r6605, %r6604, %r5449; + bfe.u32 %r6606, %r6605, 4, 14; + cvt.u64.u32 %rd447, %r6606; + or.b64 %rd397, %rd447, 4611686293372403712; + add.s32 %r6607, %r5490, 32; + bfe.u32 %r6608, %r6607, 4, 14; + cvt.u64.u32 %rd448, %r6608; + or.b64 %rd398, %rd448, 4611686293338849280; + mov.pred %p441, -1; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd397, %rd398, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6609, %r6600, 64; + add.s32 %r6610, %r6609, %r5449; + bfe.u32 %r6611, %r6610, 4, 14; + cvt.u64.u32 %rd449, %r6611; + or.b64 %rd399, %rd449, 4611686293372403712; + add.s32 %r6612, %r5490, 64; + bfe.u32 %r6613, %r6612, 4, 14; + cvt.u64.u32 %rd450, %r6613; + or.b64 %rd400, %rd450, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd399, %rd400, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6614, %r6600, 96; + add.s32 %r6615, %r6614, %r5449; + bfe.u32 %r6616, %r6615, 4, 14; + cvt.u64.u32 %rd451, %r6616; + or.b64 %rd401, %rd451, 4611686293372403712; + add.s32 %r6617, %r5490, 96; + bfe.u32 %r6618, %r6617, 4, 14; + cvt.u64.u32 %rd452, %r6618; + or.b64 %rd402, %rd452, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd401, %rd402, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6619, %r6600, 16384; + add.s32 %r6620, %r6619, %r5449; + bfe.u32 %r6621, %r6620, 4, 14; + cvt.u64.u32 %rd453, %r6621; + or.b64 %rd403, %rd453, 4611686293372403712; + add.s32 %r6622, %r5490, 8192; + bfe.u32 %r6623, %r6622, 4, 14; + cvt.u64.u32 %rd454, %r6623; + or.b64 %rd404, %rd454, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd403, %rd404, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6624, %r6600, 16416; + add.s32 %r6625, %r6624, %r5449; + bfe.u32 %r6626, %r6625, 4, 14; + cvt.u64.u32 %rd455, %r6626; + or.b64 %rd405, %rd455, 4611686293372403712; + add.s32 %r6627, %r5490, 8224; + bfe.u32 %r6628, %r6627, 4, 14; + cvt.u64.u32 %rd456, %r6628; + or.b64 %rd406, %rd456, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd405, %rd406, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6629, %r6600, 16448; + add.s32 %r6630, %r6629, %r5449; + bfe.u32 %r6631, %r6630, 4, 14; + cvt.u64.u32 %rd457, %r6631; + or.b64 %rd407, %rd457, 4611686293372403712; + add.s32 %r6632, %r5490, 8256; + bfe.u32 %r6633, %r6632, 4, 14; + cvt.u64.u32 %rd458, %r6633; + or.b64 %rd408, %rd458, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd407, %rd408, %p441, 1, 1, 0, 0; + // end inline asm + or.b32 %r6634, %r6600, 16480; + add.s32 %r6635, %r6634, %r5449; + bfe.u32 %r6636, %r6635, 4, 14; + cvt.u64.u32 %rd459, %r6636; + or.b64 %rd409, %rd459, 4611686293372403712; + add.s32 %r6637, %r5490, 8288; + bfe.u32 %r6638, %r6637, 4, 14; + cvt.u64.u32 %rd460, %r6638; + or.b64 %rd410, %rd460, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064}, %rd409, %rd410, %p441, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r5450, %r5488; + mov.b32 %r5451, %r5488; + mov.b32 %r5453, %r5488; + mov.b32 %r5454, %r5488; + mov.b32 %r5452, %r5490; + // begin inline asm + // wait for regs: %r5033,%r5034,%r5035,%r5036,%r5037,%r5038,%r5039,%r5040,%r5041,%r5042,%r5043,%r5044,%r5045,%r5046,%r5047,%r5048,%r5049,%r5050,%r5051,%r5052,%r5053,%r5054,%r5055,%r5056,%r5057,%r5058,%r5059,%r5060,%r5061,%r5062,%r5063,%r5064,%r5449,%r5450,%r5451,%r5452,%r5453,%r5454 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 461 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:461:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6639, %r5033, 0f3DB504F3; + mul.f32 %r6640, %r5034, 0f3DB504F3; + mul.f32 %r6641, %r5035, 0f3DB504F3; + mul.f32 %r6642, %r5036, 0f3DB504F3; + mul.f32 %r6643, %r5037, 0f3DB504F3; + mul.f32 %r6644, %r5038, 0f3DB504F3; + mul.f32 %r6645, %r5039, 0f3DB504F3; + mul.f32 %r6646, %r5040, 0f3DB504F3; + mul.f32 %r6647, %r5041, 0f3DB504F3; + mul.f32 %r6648, %r5042, 0f3DB504F3; + mul.f32 %r6649, %r5043, 0f3DB504F3; + mul.f32 %r6650, %r5044, 0f3DB504F3; + mul.f32 %r6651, %r5045, 0f3DB504F3; + mul.f32 %r6652, %r5046, 0f3DB504F3; + mul.f32 %r6653, %r5047, 0f3DB504F3; + mul.f32 %r6654, %r5048, 0f3DB504F3; + mul.f32 %r6655, %r5049, 0f3DB504F3; + mul.f32 %r6656, %r5050, 0f3DB504F3; + mul.f32 %r6657, %r5051, 0f3DB504F3; + mul.f32 %r6658, %r5052, 0f3DB504F3; + mul.f32 %r6659, %r5053, 0f3DB504F3; + mul.f32 %r6660, %r5054, 0f3DB504F3; + mul.f32 %r6661, %r5055, 0f3DB504F3; + mul.f32 %r6662, %r5056, 0f3DB504F3; + mul.f32 %r6663, %r5057, 0f3DB504F3; + mul.f32 %r6664, %r5058, 0f3DB504F3; + mul.f32 %r6665, %r5059, 0f3DB504F3; + mul.f32 %r6666, %r5060, 0f3DB504F3; + mul.f32 %r6667, %r5061, 0f3DB504F3; + mul.f32 %r6668, %r5062, 0f3DB504F3; + mul.f32 %r6669, %r5063, 0f3DB504F3; + mul.f32 %r6670, %r5064, 0f3DB504F3; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6671, %r6639, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6672, %r6671, 0fFF800000, %p463; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6673, %r6640, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6674, %r6673, 0fFF800000, %p464; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6675, %r6641, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6676, %r6675, 0fFF800000, %p463; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6677, %r6642, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6678, %r6677, 0fFF800000, %p464; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6679, %r6643, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6680, %r6679, 0fFF800000, %p465; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6681, %r6644, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6682, %r6681, 0fFF800000, %p466; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6683, %r6645, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6684, %r6683, 0fFF800000, %p465; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6685, %r6646, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6686, %r6685, 0fFF800000, %p466; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6687, %r6647, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6688, %r6687, 0fFF800000, %p467; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6689, %r6648, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6690, %r6689, 0fFF800000, %p468; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6691, %r6649, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6692, %r6691, 0fFF800000, %p467; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6693, %r6650, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6694, %r6693, 0fFF800000, %p468; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6695, %r6651, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6696, %r6695, 0fFF800000, %p469; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6697, %r6652, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6698, %r6697, 0fFF800000, %p470; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6699, %r6653, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6700, %r6699, 0fFF800000, %p469; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6701, %r6654, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6702, %r6701, 0fFF800000, %p470; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6703, %r6655, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6704, %r6703, 0fFF800000, %p471; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6705, %r6656, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6706, %r6705, 0fFF800000, %p472; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6707, %r6657, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6708, %r6707, 0fFF800000, %p471; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6709, %r6658, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6710, %r6709, 0fFF800000, %p472; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6711, %r6659, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6712, %r6711, 0fFF800000, %p473; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6713, %r6660, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6714, %r6713, 0fFF800000, %p474; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6715, %r6661, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6716, %r6715, 0fFF800000, %p473; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6717, %r6662, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6718, %r6717, 0fFF800000, %p474; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6719, %r6663, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6720, %r6719, 0fFF800000, %p475; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6721, %r6664, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6722, %r6721, 0fFF800000, %p476; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6723, %r6665, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6724, %r6723, 0fFF800000, %p475; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6725, %r6666, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6726, %r6725, 0fFF800000, %p476; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6727, %r6667, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6728, %r6727, 0fFF800000, %p477; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6729, %r6668, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6730, %r6729, 0fFF800000, %p478; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6731, %r6669, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6732, %r6731, 0fFF800000, %p477; + .loc 1 506 27 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:506:27 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6733, %r6670, 0f3FB8AA3B; + .loc 1 476 79 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:476:79 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.f32 %r6734, %r6733, 0fFF800000, %p478; + .loc 1 507 39 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:507:39 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + sub.f32 %r6735, %r6672, %r34; + sub.f32 %r6736, %r6674, %r34; + sub.f32 %r6737, %r6676, %r35; + sub.f32 %r6738, %r6678, %r35; + sub.f32 %r6739, %r6680, %r34; + sub.f32 %r6740, %r6682, %r34; + sub.f32 %r6741, %r6684, %r35; + sub.f32 %r6742, %r6686, %r35; + sub.f32 %r6743, %r6688, %r34; + sub.f32 %r6744, %r6690, %r34; + sub.f32 %r6745, %r6692, %r35; + sub.f32 %r6746, %r6694, %r35; + sub.f32 %r6747, %r6696, %r34; + sub.f32 %r6748, %r6698, %r34; + sub.f32 %r6749, %r6700, %r35; + sub.f32 %r6750, %r6702, %r35; + sub.f32 %r6751, %r6704, %r34; + sub.f32 %r6752, %r6706, %r34; + sub.f32 %r6753, %r6708, %r35; + sub.f32 %r6754, %r6710, %r35; + sub.f32 %r6755, %r6712, %r34; + sub.f32 %r6756, %r6714, %r34; + sub.f32 %r6757, %r6716, %r35; + sub.f32 %r6758, %r6718, %r35; + sub.f32 %r6759, %r6720, %r34; + sub.f32 %r6760, %r6722, %r34; + sub.f32 %r6761, %r6724, %r35; + sub.f32 %r6762, %r6726, %r35; + sub.f32 %r6763, %r6728, %r34; + sub.f32 %r6764, %r6730, %r34; + sub.f32 %r6765, %r6732, %r35; + sub.f32 %r6766, %r6734, %r35; + .loc 1 507 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:507:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + ex2.approx.ftz.f32 %r6767, %r6735; + ex2.approx.ftz.f32 %r6768, %r6736; + ex2.approx.ftz.f32 %r6769, %r6737; + ex2.approx.ftz.f32 %r6770, %r6738; + ex2.approx.ftz.f32 %r6771, %r6739; + ex2.approx.ftz.f32 %r6772, %r6740; + ex2.approx.ftz.f32 %r6773, %r6741; + ex2.approx.ftz.f32 %r6774, %r6742; + ex2.approx.ftz.f32 %r6775, %r6743; + ex2.approx.ftz.f32 %r6776, %r6744; + ex2.approx.ftz.f32 %r6777, %r6745; + ex2.approx.ftz.f32 %r6778, %r6746; + ex2.approx.ftz.f32 %r6779, %r6747; + ex2.approx.ftz.f32 %r6780, %r6748; + ex2.approx.ftz.f32 %r6781, %r6749; + ex2.approx.ftz.f32 %r6782, %r6750; + ex2.approx.ftz.f32 %r6783, %r6751; + ex2.approx.ftz.f32 %r6784, %r6752; + ex2.approx.ftz.f32 %r6785, %r6753; + ex2.approx.ftz.f32 %r6786, %r6754; + ex2.approx.ftz.f32 %r6787, %r6755; + ex2.approx.ftz.f32 %r6788, %r6756; + ex2.approx.ftz.f32 %r6789, %r6757; + ex2.approx.ftz.f32 %r6790, %r6758; + ex2.approx.ftz.f32 %r6791, %r6759; + ex2.approx.ftz.f32 %r6792, %r6760; + ex2.approx.ftz.f32 %r6793, %r6761; + ex2.approx.ftz.f32 %r6794, %r6762; + ex2.approx.ftz.f32 %r6795, %r6763; + ex2.approx.ftz.f32 %r6796, %r6764; + ex2.approx.ftz.f32 %r6797, %r6765; + ex2.approx.ftz.f32 %r6798, %r6766; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r6799, %r2559, 49152; + add.s32 %r6008, %r6799, %r6596; + .loc 1 512 20 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:512:20 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + wgmma.fence.sync.aligned; + add.s32 %r6005, %r2559, 131072; + add.s32 %r6800, %r6600, %r6005; + bfe.u32 %r6801, %r6800, 4, 14; + cvt.u64.u32 %rd461, %r6801; + or.b64 %rd411, %rd461, 4611686293372403712; + bfe.u32 %r6802, %r6008, 4, 14; + cvt.u64.u32 %rd462, %r6802; + or.b64 %rd412, %rd462, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd411, %rd412, 0, 1, 1, 0, 0; + // end inline asm + add.s32 %r6803, %r6604, %r6005; + bfe.u32 %r6804, %r6803, 4, 14; + cvt.u64.u32 %rd463, %r6804; + or.b64 %rd413, %rd463, 4611686293372403712; + add.s32 %r6805, %r6008, 32; + bfe.u32 %r6806, %r6805, 4, 14; + cvt.u64.u32 %rd464, %r6806; + or.b64 %rd414, %rd464, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd413, %rd414, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6807, %r6609, %r6005; + bfe.u32 %r6808, %r6807, 4, 14; + cvt.u64.u32 %rd465, %r6808; + or.b64 %rd415, %rd465, 4611686293372403712; + add.s32 %r6809, %r6008, 64; + bfe.u32 %r6810, %r6809, 4, 14; + cvt.u64.u32 %rd466, %r6810; + or.b64 %rd416, %rd466, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd415, %rd416, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6811, %r6614, %r6005; + bfe.u32 %r6812, %r6811, 4, 14; + cvt.u64.u32 %rd467, %r6812; + or.b64 %rd417, %rd467, 4611686293372403712; + add.s32 %r6813, %r6008, 96; + bfe.u32 %r6814, %r6813, 4, 14; + cvt.u64.u32 %rd468, %r6814; + or.b64 %rd418, %rd468, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd417, %rd418, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6815, %r6619, %r6005; + bfe.u32 %r6816, %r6815, 4, 14; + cvt.u64.u32 %rd469, %r6816; + or.b64 %rd419, %rd469, 4611686293372403712; + add.s32 %r6817, %r6008, 8192; + bfe.u32 %r6818, %r6817, 4, 14; + cvt.u64.u32 %rd470, %r6818; + or.b64 %rd420, %rd470, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd419, %rd420, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6819, %r6624, %r6005; + bfe.u32 %r6820, %r6819, 4, 14; + cvt.u64.u32 %rd471, %r6820; + or.b64 %rd421, %rd471, 4611686293372403712; + add.s32 %r6821, %r6008, 8224; + bfe.u32 %r6822, %r6821, 4, 14; + cvt.u64.u32 %rd472, %r6822; + or.b64 %rd422, %rd472, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd421, %rd422, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6823, %r6629, %r6005; + bfe.u32 %r6824, %r6823, 4, 14; + cvt.u64.u32 %rd473, %r6824; + or.b64 %rd423, %rd473, 4611686293372403712; + add.s32 %r6825, %r6008, 8256; + bfe.u32 %r6826, %r6825, 4, 14; + cvt.u64.u32 %rd474, %r6826; + or.b64 %rd424, %rd474, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd423, %rd424, %p441, 1, 1, 0, 0; + // end inline asm + add.s32 %r6827, %r6634, %r6005; + bfe.u32 %r6828, %r6827, 4, 14; + cvt.u64.u32 %rd475, %r6828; + or.b64 %rd425, %rd475, 4611686293372403712; + add.s32 %r6829, %r6008, 8288; + bfe.u32 %r6830, %r6829, 4, 14; + cvt.u64.u32 %rd476, %r6830; + or.b64 %rd426, %rd476, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620}, %rd425, %rd426, %p441, 1, 1, 0, 0; + // end inline asm + wgmma.commit_group.sync.aligned; + mov.b32 %r6010, %r5488; + mov.b32 %r6006, %r5488; + mov.b32 %r6007, %r5488; + mov.b32 %r6009, %r5488; + // begin inline asm + // wait for regs: %r5589,%r5590,%r5591,%r5592,%r5593,%r5594,%r5595,%r5596,%r5597,%r5598,%r5599,%r5600,%r5601,%r5602,%r5603,%r5604,%r5605,%r5606,%r5607,%r5608,%r5609,%r5610,%r5611,%r5612,%r5613,%r5614,%r5615,%r5616,%r5617,%r5618,%r5619,%r5620,%r6005,%r6006,%r6007,%r6008,%r6009,%r6010 + wgmma.wait_group.sync.aligned 0; + // end inline asm + .loc 1 513 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + sub.f32 %r6831, %r5589, %r2481; + sub.f32 %r6832, %r5590, %r2481; + sub.f32 %r6833, %r5591, %r2482; + sub.f32 %r6834, %r5592, %r2482; + sub.f32 %r6835, %r5593, %r2481; + sub.f32 %r6836, %r5594, %r2481; + sub.f32 %r6837, %r5595, %r2482; + sub.f32 %r6838, %r5596, %r2482; + sub.f32 %r6839, %r5597, %r2481; + sub.f32 %r6840, %r5598, %r2481; + sub.f32 %r6841, %r5599, %r2482; + sub.f32 %r6842, %r5600, %r2482; + sub.f32 %r6843, %r5601, %r2481; + sub.f32 %r6844, %r5602, %r2481; + sub.f32 %r6845, %r5603, %r2482; + sub.f32 %r6846, %r5604, %r2482; + sub.f32 %r6847, %r5605, %r2481; + sub.f32 %r6848, %r5606, %r2481; + sub.f32 %r6849, %r5607, %r2482; + sub.f32 %r6850, %r5608, %r2482; + sub.f32 %r6851, %r5609, %r2481; + sub.f32 %r6852, %r5610, %r2481; + sub.f32 %r6853, %r5611, %r2482; + sub.f32 %r6854, %r5612, %r2482; + sub.f32 %r6855, %r5613, %r2481; + sub.f32 %r6856, %r5614, %r2481; + sub.f32 %r6857, %r5615, %r2482; + sub.f32 %r6858, %r5616, %r2482; + sub.f32 %r6859, %r5617, %r2481; + sub.f32 %r6860, %r5618, %r2481; + sub.f32 %r6861, %r5619, %r2482; + sub.f32 %r6862, %r5620, %r2482; + .loc 1 513 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:513:14 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.f32 %r6863, %r6767, %r6831; + mul.f32 %r6864, %r6768, %r6832; + mul.f32 %r6865, %r6769, %r6833; + mul.f32 %r6866, %r6770, %r6834; + mul.f32 %r6867, %r6771, %r6835; + mul.f32 %r6868, %r6772, %r6836; + mul.f32 %r6869, %r6773, %r6837; + mul.f32 %r6870, %r6774, %r6838; + mul.f32 %r6871, %r6775, %r6839; + mul.f32 %r6872, %r6776, %r6840; + mul.f32 %r6873, %r6777, %r6841; + mul.f32 %r6874, %r6778, %r6842; + mul.f32 %r6875, %r6779, %r6843; + mul.f32 %r6876, %r6780, %r6844; + mul.f32 %r6877, %r6781, %r6845; + mul.f32 %r6878, %r6782, %r6846; + mul.f32 %r6879, %r6783, %r6847; + mul.f32 %r6880, %r6784, %r6848; + mul.f32 %r6881, %r6785, %r6849; + mul.f32 %r6882, %r6786, %r6850; + mul.f32 %r6883, %r6787, %r6851; + mul.f32 %r6884, %r6788, %r6852; + mul.f32 %r6885, %r6789, %r6853; + mul.f32 %r6886, %r6790, %r6854; + mul.f32 %r6887, %r6791, %r6855; + mul.f32 %r6888, %r6792, %r6856; + mul.f32 %r6889, %r6793, %r6857; + mul.f32 %r6890, %r6794, %r6858; + mul.f32 %r6891, %r6795, %r6859; + mul.f32 %r6892, %r6796, %r6860; + mul.f32 %r6893, %r6797, %r6861; + mul.f32 %r6894, %r6798, %r6862; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs65, %r6863; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs66, %rs65, 0x0000, %p463; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs67, %r6864; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs68, %rs67, 0x0000, %p464; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs69, %r6865; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs70, %rs69, 0x0000, %p463; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs71, %r6866; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs72, %rs71, 0x0000, %p464; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs73, %r6867; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs74, %rs73, 0x0000, %p465; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs75, %r6868; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs76, %rs75, 0x0000, %p466; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs77, %r6869; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs78, %rs77, 0x0000, %p465; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs79, %r6870; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs80, %rs79, 0x0000, %p466; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs81, %r6871; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs82, %rs81, 0x0000, %p467; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs83, %r6872; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs84, %rs83, 0x0000, %p468; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs85, %r6873; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs86, %rs85, 0x0000, %p467; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs87, %r6874; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs88, %rs87, 0x0000, %p468; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs89, %r6875; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs90, %rs89, 0x0000, %p469; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs91, %r6876; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs92, %rs91, 0x0000, %p470; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs93, %r6877; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs94, %rs93, 0x0000, %p469; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs95, %r6878; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs96, %rs95, 0x0000, %p470; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs97, %r6879; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs98, %rs97, 0x0000, %p471; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs99, %r6880; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs100, %rs99, 0x0000, %p472; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs101, %r6881; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs102, %rs101, 0x0000, %p471; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs103, %r6882; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs104, %rs103, 0x0000, %p472; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs105, %r6883; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs106, %rs105, 0x0000, %p473; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs107, %r6884; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs108, %rs107, 0x0000, %p474; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs109, %r6885; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs110, %rs109, 0x0000, %p473; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs111, %r6886; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs112, %rs111, 0x0000, %p474; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs113, %r6887; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs114, %rs113, 0x0000, %p475; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs115, %r6888; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs116, %rs115, 0x0000, %p476; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs117, %r6889; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs118, %rs117, 0x0000, %p475; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs119, %r6890; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs120, %rs119, 0x0000, %p476; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs121, %r6891; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs122, %rs121, 0x0000, %p477; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs123, %r6892; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs124, %rs123, 0x0000, %p478; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs125, %r6893; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs126, %rs125, 0x0000, %p477; + .loc 1 533 15 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:533:15 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + cvt.rn.bf16.f32 %rs127, %r6894; + .loc 1 520 71 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:520:71 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + selp.b16 %rs128, %rs127, 0x0000, %p478; + .loc 1 535 21 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:535:21 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mov.b32 %r6177, {%rs66, %rs68}; + mov.b32 %r6178, {%rs70, %rs72}; + mov.b32 %r6179, {%rs74, %rs76}; + mov.b32 %r6180, {%rs78, %rs80}; + mov.b32 %r6309, {%rs82, %rs84}; + mov.b32 %r6310, {%rs86, %rs88}; + mov.b32 %r6311, {%rs90, %rs92}; + mov.b32 %r6312, {%rs94, %rs96}; + mov.b32 %r6441, {%rs98, %rs100}; + mov.b32 %r6442, {%rs102, %rs104}; + mov.b32 %r6443, {%rs106, %rs108}; + mov.b32 %r6444, {%rs110, %rs112}; + mov.b32 %r6573, {%rs114, %rs116}; + mov.b32 %r6574, {%rs118, %rs120}; + mov.b32 %r6575, {%rs122, %rs124}; + mov.b32 %r6576, {%rs126, %rs128}; + wgmma.fence.sync.aligned; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r6177,%r6178,%r6179,%r6180}, %rd396, %p441, 1, 1, 1; + // end inline asm + add.s32 %r6895, %r5490, 2048; + bfe.u32 %r6896, %r6895, 4, 14; + cvt.u64.u32 %rd477, %r6896; + or.b64 %rd428, %rd477, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r6309,%r6310,%r6311,%r6312}, %rd428, %p441, 1, 1, 1; + // end inline asm + add.s32 %r6897, %r5490, 4096; + bfe.u32 %r6898, %r6897, 4, 14; + cvt.u64.u32 %rd478, %r6898; + or.b64 %rd429, %rd478, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r6441,%r6442,%r6443,%r6444}, %rd429, %p441, 1, 1, 1; + // end inline asm + add.s32 %r6899, %r5490, 6144; + bfe.u32 %r6900, %r6899, 4, 14; + cvt.u64.u32 %rd479, %r6900; + or.b64 %rd430, %rd479, 4611686293338849280; + // begin inline asm + wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 {%r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255}, {%r6573,%r6574,%r6575,%r6576}, %rd430, %p441, 1, 1, 1; + // end inline asm + wgmma.commit_group.sync.aligned; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r14409, %r14337, %r14409; + add.s32 %r14410, %r14337, %r14410; + add.s32 %r14411, %r14337, %r14411; + add.s32 %r14412, %r14337, %r14412; + add.s32 %r14413, %r14337, %r14413; + add.s32 %r14414, %r14337, %r14414; + add.s32 %r14415, %r14337, %r14415; + add.s32 %r14416, %r14337, %r14416; + add.s32 %r14417, %r14337, %r14417; + add.s32 %r14418, %r14337, %r14418; + add.s32 %r14419, %r14337, %r14419; + add.s32 %r14420, %r14337, %r14420; + add.s32 %r14421, %r14337, %r14421; + add.s32 %r14422, %r14337, %r14422; + add.s32 %r14423, %r14337, %r14423; + add.s32 %r14424, %r14337, %r14424; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r608, %r14408, 1; + .loc 1 752 33 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:752:33 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shr.u32 %r6901, %r608, 1; + .loc 1 753 38 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:38 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mad.wide.u32 %rd432, %r6901, 4, %rd363; + .loc 1 753 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:753:24 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + // begin inline asm + mov.u64 %rd431, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd431, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r6577, 0x0; + @%p459 ld.global.L1::evict_last.L2::cache_hint.b32 { %r6577 }, [ %rd432 + 0 ], %rd431; + // end inline asm + .loc 1 754 109 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:109 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r6902, %r6901, 1; + .loc 1 754 113 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:113 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p479, %r6902, %r4883; + .loc 1 754 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:55 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd435, %rd432, 4; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + and.pred %p460, %p459, %p479; + .loc 1 754 25 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:754:25 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + // begin inline asm + mov.u64 %rd434, 0x0; + createpolicy.fractional.L2::evict_last.b64 %rd434, 1.0; + // end inline asm + // begin inline asm + mov.u32 %r6578, 0x0; + @%p460 ld.global.L1::evict_last.L2::cache_hint.b32 { %r6578 }, [ %rd435 + 0 ], %rd434; + // end inline asm + .loc 1 755 35 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:755:35 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + and.b32 %r6903, %r14408, 1; + .loc 1 756 34 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:34 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + sub.s32 %r6904, %r6578, %r6577; + .loc 1 756 48 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:48 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r6905, %r6904, 7; + .loc 1 756 63 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:756:63 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r6906, %r6905, -64; + .loc 1 757 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:29 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + xor.b32 %r6907, %r6903, 1; + .loc 1 757 61 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:61 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r6908, %r6903, 6; + .loc 1 757 42 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:757:42 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mad.lo.s32 %r14337, %r6906, %r6907, %r6908; + .loc 1 414 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r6909, %r14337, 10; + .loc 1 414 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:414:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + mul.wide.s32 %rd480, %r6909, 2; + add.s64 %rd1119, %rd1119, %rd480; + add.s64 %rd1118, %rd1118, %rd480; + add.s64 %rd1117, %rd1117, %rd480; + add.s64 %rd1116, %rd1116, %rd480; + .loc 1 415 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:415:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s64 %rd1115, %rd1115, %rd480; + add.s64 %rd1114, %rd1114, %rd480; + add.s64 %rd1113, %rd1113, %rd480; + add.s64 %rd1112, %rd1112, %rd480; + .loc 1 417 19 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:417:19 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r14343, %r14337, %r14343; + add.s32 %r14342, %r14337, %r14342; + add.s32 %r14341, %r14337, %r14341; + add.s32 %r14340, %r14337, %r14340; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + add.s32 %r6910, %r14339, 1; + setp.gt.s32 %p480, %r6910, 2; + selp.b32 %r14339, 0, %r6910, %p480; + .loc 1 795 52 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:52 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.lt.s32 %p481, %r14343, %r2339; + setp.lt.s32 %p482, %r14342, %r2339; + setp.lt.s32 %p483, %r14341, %r2339; + setp.lt.s32 %p484, %r14340, %r2339; + .loc 1 795 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:795:23 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + shl.b32 %r6911, %r14339, 14; + add.s32 %r6912, %r2559, %r6911; + bar.sync 0; + add.s32 %r6579, %r6912, %r59; + selp.b32 %r6913, 16, 0, %p481; + selp.b32 %r6588, %r6913, 0, %p461; + // begin inline asm + cp.async.cg.shared.global [ %r6579 + 0 ], [ %rd1119 + 0 ], 0x10, %r6588; + // end inline asm + add.s32 %r6581, %r6579, 2048; + selp.b32 %r6914, 16, 0, %p482; + selp.b32 %r6590, %r6914, 0, %p461; + // begin inline asm + cp.async.cg.shared.global [ %r6581 + 0 ], [ %rd1118 + 0 ], 0x10, %r6590; + // end inline asm + add.s32 %r6583, %r6579, 4096; + selp.b32 %r6915, 16, 0, %p483; + selp.b32 %r6592, %r6915, 0, %p461; + // begin inline asm + cp.async.cg.shared.global [ %r6583 + 0 ], [ %rd1117 + 0 ], 0x10, %r6592; + // end inline asm + add.s32 %r6585, %r6579, 6144; + selp.b32 %r6916, 16, 0, %p484; + selp.b32 %r6594, %r6916, 0, %p461; + // begin inline asm + cp.async.cg.shared.global [ %r6585 + 0 ], [ %rd1116 + 0 ], 0x10, %r6594; + // end inline asm + cp.async.commit_group; + add.s32 %r6917, %r6799, %r6911; + add.s32 %r6587, %r6917, %r59; + // begin inline asm + cp.async.cg.shared.global [ %r6587 + 0 ], [ %rd1115 + 0 ], 0x10, %r6588; + // end inline asm + add.s32 %r6589, %r6587, 2048; + // begin inline asm + cp.async.cg.shared.global [ %r6589 + 0 ], [ %rd1114 + 0 ], 0x10, %r6590; + // end inline asm + add.s32 %r6591, %r6587, 4096; + // begin inline asm + cp.async.cg.shared.global [ %r6591 + 0 ], [ %rd1113 + 0 ], 0x10, %r6592; + // end inline asm + add.s32 %r6593, %r6587, 6144; + // begin inline asm + cp.async.cg.shared.global [ %r6593 + 0 ], [ %rd1112 + 0 ], 0x10, %r6594; + // end inline asm + cp.async.commit_group; + .loc 1 397 28 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:397:28 @[ cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:226:16 ] + setp.ne.b32 %p485, %r438, %r608; + mov.b32 %r14408, %r608; + @%p485 bra $L__BB0_6; +$L__BB0_7: // %._crit_edge1673 + // begin inline asm + // wait for regs: %r14192,%r14193,%r14194,%r14195,%r14196,%r14197,%r14198,%r14199,%r14200,%r14201,%r14202,%r14203,%r14204,%r14205,%r14206,%r14207,%r14208,%r14209,%r14210,%r14211,%r14212,%r14213,%r14214,%r14215,%r14216,%r14217,%r14218,%r14219,%r14220,%r14221,%r14222,%r14223,%r14224,%r14225,%r14226,%r14227,%r14228,%r14229,%r14230,%r14231,%r14232,%r14233,%r14234,%r14235,%r14236,%r14237,%r14238,%r14239,%r14240,%r14241,%r14242,%r14243,%r14244,%r14245,%r14246,%r14247,%r14248,%r14249,%r14250,%r14251,%r14252,%r14253,%r14254,%r14255 + wgmma.wait_group.sync.aligned 0; + // end inline asm + cp.async.wait_group 0; + bar.sync 0; +$L__tmp33: + .loc 1 231 24 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:231:24 + shl.b64 %rd489, %rd5, 1; + add.s64 %rd490, %rd4, %rd489; + shl.b64 %rd491, %rd6, 1; + add.s64 %rd492, %rd4, %rd491; + shl.b64 %rd493, %rd7, 1; + add.s64 %rd494, %rd4, %rd493; + shl.b64 %rd495, %rd8, 1; + add.s64 %rd496, %rd4, %rd495; + shl.b64 %rd497, %rd9, 1; + add.s64 %rd498, %rd4, %rd497; + shl.b64 %rd499, %rd10, 1; + add.s64 %rd500, %rd4, %rd499; + shl.b64 %rd501, %rd11, 1; + add.s64 %rd502, %rd4, %rd501; + shl.b64 %rd503, %rd12, 1; + add.s64 %rd504, %rd4, %rd503; + .loc 1 231 56 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:231:56 + add.s64 %rd481, %rd490, %rd390; + add.s64 %rd482, %rd492, %rd390; + add.s64 %rd483, %rd494, %rd390; + add.s64 %rd484, %rd496, %rd390; + add.s64 %rd485, %rd498, %rd390; + add.s64 %rd486, %rd500, %rd390; + add.s64 %rd487, %rd502, %rd390; + add.s64 %rd488, %rd504, %rd390; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7118, %r14192, 0f3DB504F3; + mul.f32 %r7119, %r14193, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7120, %r7119, %r7118; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7121, %r14194, 0f3DB504F3; + mul.f32 %r7122, %r14195, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7123, %r7122, %r7121; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7124, %r14196, 0f3DB504F3; + mul.f32 %r7125, %r14197, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7126, %r7125, %r7124; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7127, %r14198, 0f3DB504F3; + mul.f32 %r7128, %r14199, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7129, %r7128, %r7127; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7130, %r14200, 0f3DB504F3; + mul.f32 %r7131, %r14201, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7132, %r7131, %r7130; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7133, %r14202, 0f3DB504F3; + mul.f32 %r7134, %r14203, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7135, %r7134, %r7133; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7136, %r14204, 0f3DB504F3; + mul.f32 %r7137, %r14205, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7138, %r7137, %r7136; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7139, %r14206, 0f3DB504F3; + mul.f32 %r7140, %r14207, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7141, %r7140, %r7139; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7142, %r14208, 0f3DB504F3; + mul.f32 %r7143, %r14209, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7144, %r7143, %r7142; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7145, %r14210, 0f3DB504F3; + mul.f32 %r7146, %r14211, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7147, %r7146, %r7145; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7148, %r14212, 0f3DB504F3; + mul.f32 %r7149, %r14213, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7150, %r7149, %r7148; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7151, %r14214, 0f3DB504F3; + mul.f32 %r7152, %r14215, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7153, %r7152, %r7151; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7154, %r14216, 0f3DB504F3; + mul.f32 %r7155, %r14217, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7156, %r7155, %r7154; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7157, %r14218, 0f3DB504F3; + mul.f32 %r7158, %r14219, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7159, %r7158, %r7157; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7160, %r14220, 0f3DB504F3; + mul.f32 %r7161, %r14221, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7162, %r7161, %r7160; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7163, %r14222, 0f3DB504F3; + mul.f32 %r7164, %r14223, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7165, %r7164, %r7163; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7166, %r14224, 0f3DB504F3; + mul.f32 %r7167, %r14225, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7168, %r7167, %r7166; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7169, %r14226, 0f3DB504F3; + mul.f32 %r7170, %r14227, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7171, %r7170, %r7169; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7172, %r14228, 0f3DB504F3; + mul.f32 %r7173, %r14229, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7174, %r7173, %r7172; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7175, %r14230, 0f3DB504F3; + mul.f32 %r7176, %r14231, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7177, %r7176, %r7175; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7178, %r14232, 0f3DB504F3; + mul.f32 %r7179, %r14233, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7180, %r7179, %r7178; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7181, %r14234, 0f3DB504F3; + mul.f32 %r7182, %r14235, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7183, %r7182, %r7181; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7184, %r14236, 0f3DB504F3; + mul.f32 %r7185, %r14237, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7186, %r7185, %r7184; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7187, %r14238, 0f3DB504F3; + mul.f32 %r7188, %r14239, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7189, %r7188, %r7187; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7190, %r14240, 0f3DB504F3; + mul.f32 %r7191, %r14241, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7192, %r7191, %r7190; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7193, %r14242, 0f3DB504F3; + mul.f32 %r7194, %r14243, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7195, %r7194, %r7193; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7196, %r14244, 0f3DB504F3; + mul.f32 %r7197, %r14245, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7198, %r7197, %r7196; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7199, %r14246, 0f3DB504F3; + mul.f32 %r7200, %r14247, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7201, %r7200, %r7199; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7202, %r14248, 0f3DB504F3; + mul.f32 %r7203, %r14249, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7204, %r7203, %r7202; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7205, %r14250, 0f3DB504F3; + mul.f32 %r7206, %r14251, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7207, %r7206, %r7205; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7208, %r14252, 0f3DB504F3; + mul.f32 %r7209, %r14253, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7210, %r7209, %r7208; + .loc 1 232 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:232:14 + mul.f32 %r7211, %r14254, 0f3DB504F3; + mul.f32 %r7212, %r14255, 0f3DB504F3; + .loc 1 236 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:236:30 + cvt.rn.bf16x2.f32 %r7213, %r7212, %r7211; + shl.b32 %r7214, %r38, 13; + shl.b32 %r7215, %r9, 5; + and.b32 %r7216, %r7215, 7264; + and.b32 %r7217, %r9, 24; + shl.b32 %r7218, %r7217, 4; + shl.b32 %r7219, %r9, 2; + and.b32 %r7220, %r7219, 16; + or.b32 %r7221, %r7214, %r7220; + or.b32 %r7222, %r7216, %r7218; + or.b32 %r7223, %r7221, %r7222; + add.s32 %r7225, %r2559, %r7223; + st.shared.v4.b32 [%r7225], {%r7120, %r7126, %r7132, %r7138}; + st.shared.v4.b32 [%r7225+512], {%r7123, %r7129, %r7135, %r7141}; + xor.b32 %r7226, %r7223, 32; + add.s32 %r7227, %r2559, %r7226; + st.shared.v4.b32 [%r7227], {%r7144, %r7150, %r7156, %r7162}; + st.shared.v4.b32 [%r7227+512], {%r7147, %r7153, %r7159, %r7165}; + xor.b32 %r7228, %r7223, 64; + add.s32 %r7229, %r2559, %r7228; + st.shared.v4.b32 [%r7229], {%r7168, %r7174, %r7180, %r7186}; + st.shared.v4.b32 [%r7229+512], {%r7171, %r7177, %r7183, %r7189}; + xor.b32 %r7230, %r7223, 96; + add.s32 %r7231, %r2559, %r7230; + st.shared.v4.b32 [%r7231], {%r7192, %r7198, %r7204, %r7210}; + st.shared.v4.b32 [%r7231+512], {%r7195, %r7201, %r7207, %r7213}; + bar.sync 0; + shl.b32 %r7232, %r7217, 10; + shl.b32 %r7233, %r38, 5; + and.b32 %r7234, %r7219, 1008; + or.b32 %r7235, %r7232, %r7233; + xor.b32 %r7236, %r7235, %r7234; + add.s32 %r7050, %r2559, %r7236; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7086, %r7087, %r7088, %r7089}, [%r7050]; + // end inline asm + add.s32 %r7055, %r7050, 1024; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7090, %r7091, %r7092, %r7093}, [%r7055]; + // end inline asm + add.s32 %r7060, %r7050, 2048; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7094, %r7095, %r7096, %r7097}, [%r7060]; + // end inline asm + add.s32 %r7065, %r7050, 3072; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7098, %r7099, %r7100, %r7101}, [%r7065]; + // end inline asm + add.s32 %r7070, %r7050, 4096; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7102, %r7103, %r7104, %r7105}, [%r7070]; + // end inline asm + add.s32 %r7075, %r7050, 5120; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7106, %r7107, %r7108, %r7109}, [%r7075]; + // end inline asm + add.s32 %r7080, %r7050, 6144; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7110, %r7111, %r7112, %r7113}, [%r7080]; + // end inline asm + add.s32 %r7085, %r7050, 7168; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r7114, %r7115, %r7116, %r7117}, [%r7085]; + // end inline asm + // begin inline asm + @%p12 st.global.v4.b32 [ %rd481 + 0 ], { %r7086, %r7087, %r7088, %r7089 }; + // end inline asm + // begin inline asm + @%p13 st.global.v4.b32 [ %rd482 + 0 ], { %r7090, %r7091, %r7092, %r7093 }; + // end inline asm + // begin inline asm + @%p14 st.global.v4.b32 [ %rd483 + 0 ], { %r7094, %r7095, %r7096, %r7097 }; + // end inline asm + // begin inline asm + @%p15 st.global.v4.b32 [ %rd484 + 0 ], { %r7098, %r7099, %r7100, %r7101 }; + // end inline asm + // begin inline asm + @%p16 st.global.v4.b32 [ %rd485 + 0 ], { %r7102, %r7103, %r7104, %r7105 }; + // end inline asm + // begin inline asm + @%p17 st.global.v4.b32 [ %rd486 + 0 ], { %r7106, %r7107, %r7108, %r7109 }; + // end inline asm + // begin inline asm + @%p18 st.global.v4.b32 [ %rd487 + 0 ], { %r7110, %r7111, %r7112, %r7113 }; + // end inline asm + // begin inline asm + @%p19 st.global.v4.b32 [ %rd488 + 0 ], { %r7114, %r7115, %r7116, %r7117 }; + // end inline asm + .loc 1 139 7 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:139:7 + bra.uni $L__BB0_17; +$L__BB0_16: + .loc 1 0 7 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:0:7 + cvt.u32.u64 %r14016, %rd75; + cvt.u32.u64 %r14017, %rd74; + cvt.u32.u64 %r14018, %rd73; + cvt.u32.u64 %r14019, %rd72; + cvt.u32.u64 %r14020, %rd71; + cvt.u32.u64 %r14021, %rd70; + cvt.u32.u64 %r14022, %rd69; + cvt.u32.u64 %r14023, %rd68; + cvt.u32.u64 %r14024, %rd67; + .loc 1 323 23 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:323:23 + shl.b64 %rd1087, %rd67, 1; + add.s64 %rd1088, %rd3, %rd1087; + shl.b64 %rd1089, %rd68, 1; + add.s64 %rd1090, %rd3, %rd1089; + shl.b64 %rd1091, %rd69, 1; + add.s64 %rd1092, %rd3, %rd1091; + shl.b64 %rd1093, %rd70, 1; + add.s64 %rd1094, %rd3, %rd1093; + shl.b64 %rd1095, %rd71, 1; + add.s64 %rd1096, %rd3, %rd1095; + shl.b64 %rd1097, %rd72, 1; + add.s64 %rd1098, %rd3, %rd1097; + shl.b64 %rd1099, %rd73, 1; + add.s64 %rd1100, %rd3, %rd1099; + shl.b64 %rd1101, %rd74, 1; + add.s64 %rd1102, %rd3, %rd1101; + .loc 1 323 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:323:55 + add.s64 %rd1071, %rd1088, %rd643; + add.s64 %rd1072, %rd1090, %rd643; + add.s64 %rd1073, %rd1092, %rd643; + add.s64 %rd1074, %rd1094, %rd643; + add.s64 %rd1075, %rd1096, %rd643; + add.s64 %rd1076, %rd1098, %rd643; + add.s64 %rd1077, %rd1100, %rd643; + add.s64 %rd1078, %rd1102, %rd643; + .loc 1 332 30 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:332:30 + cvt.rn.bf16x2.f32 %r14025, %r14663, %r14662; + cvt.rn.bf16x2.f32 %r14026, %r14665, %r14664; + cvt.rn.bf16x2.f32 %r14027, %r14667, %r14666; + cvt.rn.bf16x2.f32 %r14028, %r14669, %r14668; + cvt.rn.bf16x2.f32 %r14029, %r14671, %r14670; + cvt.rn.bf16x2.f32 %r14030, %r14673, %r14672; + cvt.rn.bf16x2.f32 %r14031, %r14675, %r14674; + cvt.rn.bf16x2.f32 %r14032, %r14677, %r14676; + cvt.rn.bf16x2.f32 %r14033, %r14679, %r14678; + cvt.rn.bf16x2.f32 %r14034, %r14681, %r14680; + cvt.rn.bf16x2.f32 %r14035, %r14683, %r14682; + cvt.rn.bf16x2.f32 %r14036, %r14685, %r14684; + cvt.rn.bf16x2.f32 %r14037, %r14687, %r14686; + cvt.rn.bf16x2.f32 %r14038, %r14689, %r14688; + cvt.rn.bf16x2.f32 %r14039, %r14691, %r14690; + cvt.rn.bf16x2.f32 %r14040, %r14693, %r14692; + cvt.rn.bf16x2.f32 %r14041, %r14695, %r14694; + cvt.rn.bf16x2.f32 %r14042, %r14697, %r14696; + cvt.rn.bf16x2.f32 %r14043, %r14699, %r14698; + cvt.rn.bf16x2.f32 %r14044, %r14701, %r14700; + cvt.rn.bf16x2.f32 %r14045, %r14703, %r14702; + cvt.rn.bf16x2.f32 %r14046, %r14705, %r14704; + cvt.rn.bf16x2.f32 %r14047, %r14707, %r14706; + cvt.rn.bf16x2.f32 %r14048, %r14709, %r14708; + cvt.rn.bf16x2.f32 %r14049, %r14711, %r14710; + cvt.rn.bf16x2.f32 %r14050, %r14713, %r14712; + cvt.rn.bf16x2.f32 %r14051, %r14715, %r14714; + cvt.rn.bf16x2.f32 %r14052, %r14717, %r14716; + cvt.rn.bf16x2.f32 %r14053, %r14719, %r14718; + cvt.rn.bf16x2.f32 %r14054, %r14721, %r14720; + cvt.rn.bf16x2.f32 %r14055, %r14723, %r14722; + cvt.rn.bf16x2.f32 %r14056, %r14725, %r14724; + shl.b32 %r14057, %r691, 13; + shl.b32 %r14058, %r9, 5; + and.b32 %r14059, %r14058, 7264; + and.b32 %r14060, %r9, 24; + shl.b32 %r14061, %r14060, 4; + shl.b32 %r14062, %r9, 2; + and.b32 %r14063, %r14062, 16; + or.b32 %r14064, %r14057, %r14063; + or.b32 %r14065, %r14059, %r14061; + or.b32 %r14066, %r14064, %r14065; + add.s32 %r14068, %r7390, %r14066; + st.shared.v4.b32 [%r14068], {%r14025, %r14027, %r14029, %r14031}; + st.shared.v4.b32 [%r14068+512], {%r14026, %r14028, %r14030, %r14032}; + xor.b32 %r14069, %r14066, 32; + add.s32 %r14070, %r7390, %r14069; + st.shared.v4.b32 [%r14070], {%r14033, %r14035, %r14037, %r14039}; + st.shared.v4.b32 [%r14070+512], {%r14034, %r14036, %r14038, %r14040}; + xor.b32 %r14071, %r14066, 64; + add.s32 %r14072, %r7390, %r14071; + st.shared.v4.b32 [%r14072], {%r14041, %r14043, %r14045, %r14047}; + st.shared.v4.b32 [%r14072+512], {%r14042, %r14044, %r14046, %r14048}; + xor.b32 %r14073, %r14066, 96; + add.s32 %r14074, %r7390, %r14073; + st.shared.v4.b32 [%r14074], {%r14049, %r14051, %r14053, %r14055}; + st.shared.v4.b32 [%r14074+512], {%r14050, %r14052, %r14054, %r14056}; + bar.sync 0; + shl.b32 %r14075, %r14060, 10; + shl.b32 %r14076, %r691, 5; + and.b32 %r14077, %r14062, 1008; + or.b32 %r14078, %r14075, %r14076; + xor.b32 %r14079, %r14078, %r14077; + add.s32 %r13876, %r7390, %r14079; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13912, %r13913, %r13914, %r13915}, [%r13876]; + // end inline asm + add.s32 %r13881, %r13876, 1024; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13916, %r13917, %r13918, %r13919}, [%r13881]; + // end inline asm + add.s32 %r13886, %r13876, 2048; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13920, %r13921, %r13922, %r13923}, [%r13886]; + // end inline asm + add.s32 %r13891, %r13876, 3072; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13924, %r13925, %r13926, %r13927}, [%r13891]; + // end inline asm + add.s32 %r13896, %r13876, 4096; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13928, %r13929, %r13930, %r13931}, [%r13896]; + // end inline asm + add.s32 %r13901, %r13876, 5120; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13932, %r13933, %r13934, %r13935}, [%r13901]; + // end inline asm + add.s32 %r13906, %r13876, 6144; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13936, %r13937, %r13938, %r13939}, [%r13906]; + // end inline asm + add.s32 %r13911, %r13876, 7168; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13940, %r13941, %r13942, %r13943}, [%r13911]; + // end inline asm + // begin inline asm + @%p1253 st.global.v4.b32 [ %rd1071 + 0 ], { %r13912, %r13913, %r13914, %r13915 }; + // end inline asm + // begin inline asm + @%p1254 st.global.v4.b32 [ %rd1072 + 0 ], { %r13916, %r13917, %r13918, %r13919 }; + // end inline asm + // begin inline asm + @%p1255 st.global.v4.b32 [ %rd1073 + 0 ], { %r13920, %r13921, %r13922, %r13923 }; + // end inline asm + // begin inline asm + @%p1256 st.global.v4.b32 [ %rd1074 + 0 ], { %r13924, %r13925, %r13926, %r13927 }; + // end inline asm + // begin inline asm + @%p1257 st.global.v4.b32 [ %rd1075 + 0 ], { %r13928, %r13929, %r13930, %r13931 }; + // end inline asm + // begin inline asm + @%p1258 st.global.v4.b32 [ %rd1076 + 0 ], { %r13932, %r13933, %r13934, %r13935 }; + // end inline asm + // begin inline asm + @%p1259 st.global.v4.b32 [ %rd1077 + 0 ], { %r13936, %r13937, %r13938, %r13939 }; + // end inline asm + // begin inline asm + @%p1260 st.global.v4.b32 [ %rd1078 + 0 ], { %r13940, %r13941, %r13942, %r13943 }; + // end inline asm + .loc 1 334 14 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:334:14 + mul.f32 %r14080, %r14726, 0f3DB504F3; + mul.f32 %r14081, %r14727, 0f3DB504F3; + mul.f32 %r14082, %r14728, 0f3DB504F3; + mul.f32 %r14083, %r14729, 0f3DB504F3; + mul.f32 %r14084, %r14730, 0f3DB504F3; + mul.f32 %r14085, %r14731, 0f3DB504F3; + mul.f32 %r14086, %r14732, 0f3DB504F3; + mul.f32 %r14087, %r14733, 0f3DB504F3; + mul.f32 %r14088, %r14734, 0f3DB504F3; + mul.f32 %r14089, %r14735, 0f3DB504F3; + mul.f32 %r14090, %r14736, 0f3DB504F3; + mul.f32 %r14091, %r14737, 0f3DB504F3; + mul.f32 %r14092, %r14738, 0f3DB504F3; + mul.f32 %r14093, %r14739, 0f3DB504F3; + mul.f32 %r14094, %r14740, 0f3DB504F3; + mul.f32 %r14095, %r14741, 0f3DB504F3; + mul.f32 %r14096, %r14742, 0f3DB504F3; + mul.f32 %r14097, %r14743, 0f3DB504F3; + mul.f32 %r14098, %r14744, 0f3DB504F3; + mul.f32 %r14099, %r14745, 0f3DB504F3; + mul.f32 %r14100, %r14746, 0f3DB504F3; + mul.f32 %r14101, %r14747, 0f3DB504F3; + mul.f32 %r14102, %r14748, 0f3DB504F3; + mul.f32 %r14103, %r14749, 0f3DB504F3; + mul.f32 %r14104, %r14750, 0f3DB504F3; + mul.f32 %r14105, %r14751, 0f3DB504F3; + mul.f32 %r14106, %r14752, 0f3DB504F3; + mul.f32 %r14107, %r14753, 0f3DB504F3; + mul.f32 %r14108, %r14754, 0f3DB504F3; + mul.f32 %r14109, %r14755, 0f3DB504F3; + mul.f32 %r14110, %r14756, 0f3DB504F3; + mul.f32 %r14111, %r14757, 0f3DB504F3; + mul.f32 %r14112, %r14758, 0f3DB504F3; + mul.f32 %r14113, %r14759, 0f3DB504F3; + mul.f32 %r14114, %r14760, 0f3DB504F3; + mul.f32 %r14115, %r14761, 0f3DB504F3; + mul.f32 %r14116, %r14762, 0f3DB504F3; + mul.f32 %r14117, %r14763, 0f3DB504F3; + mul.f32 %r14118, %r14764, 0f3DB504F3; + mul.f32 %r14119, %r14765, 0f3DB504F3; + mul.f32 %r14120, %r14766, 0f3DB504F3; + mul.f32 %r14121, %r14767, 0f3DB504F3; + mul.f32 %r14122, %r14768, 0f3DB504F3; + mul.f32 %r14123, %r14769, 0f3DB504F3; + mul.f32 %r14124, %r14770, 0f3DB504F3; + mul.f32 %r14125, %r14771, 0f3DB504F3; + mul.f32 %r14126, %r14772, 0f3DB504F3; + mul.f32 %r14127, %r14773, 0f3DB504F3; + mul.f32 %r14128, %r14774, 0f3DB504F3; + mul.f32 %r14129, %r14775, 0f3DB504F3; + mul.f32 %r14130, %r14776, 0f3DB504F3; + mul.f32 %r14131, %r14777, 0f3DB504F3; + mul.f32 %r14132, %r14778, 0f3DB504F3; + mul.f32 %r14133, %r14779, 0f3DB504F3; + mul.f32 %r14134, %r14780, 0f3DB504F3; + mul.f32 %r14135, %r14781, 0f3DB504F3; + mul.f32 %r14136, %r14782, 0f3DB504F3; + mul.f32 %r14137, %r14783, 0f3DB504F3; + mul.f32 %r14138, %r14784, 0f3DB504F3; + mul.f32 %r14139, %r14785, 0f3DB504F3; + mul.f32 %r14140, %r14786, 0f3DB504F3; + mul.f32 %r14141, %r14787, 0f3DB504F3; + mul.f32 %r14142, %r14788, 0f3DB504F3; + mul.f32 %r14143, %r14789, 0f3DB504F3; + .loc 1 345 55 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:345:55 + or.b32 %r14144, %r14016, %r8; + .loc 1 345 69 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:345:69 + add.s32 %r14145, %r14024, %r14144; + add.s32 %r14146, %r14023, %r14144; + add.s32 %r14147, %r14022, %r14144; + add.s32 %r14148, %r14021, %r14144; + add.s32 %r14149, %r14020, %r14144; + add.s32 %r14150, %r14019, %r14144; + add.s32 %r14151, %r14018, %r14144; + add.s32 %r14152, %r14017, %r14144; + .loc 1 345 29 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:345:29 + mad.wide.s32 %rd1079, %r14145, 2, %rd191; + mad.wide.s32 %rd1080, %r14146, 2, %rd191; + mad.wide.s32 %rd1081, %r14147, 2, %rd191; + mad.wide.s32 %rd1082, %r14148, 2, %rd191; + mad.wide.s32 %rd1083, %r14149, 2, %rd191; + mad.wide.s32 %rd1084, %r14150, 2, %rd191; + mad.wide.s32 %rd1085, %r14151, 2, %rd191; + mad.wide.s32 %rd1086, %r14152, 2, %rd191; + .loc 1 345 99 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:345:99 + cvt.rn.bf16x2.f32 %r14153, %r14081, %r14080; + cvt.rn.bf16x2.f32 %r14154, %r14083, %r14082; + cvt.rn.bf16x2.f32 %r14155, %r14085, %r14084; + cvt.rn.bf16x2.f32 %r14156, %r14087, %r14086; + cvt.rn.bf16x2.f32 %r14157, %r14089, %r14088; + cvt.rn.bf16x2.f32 %r14158, %r14091, %r14090; + cvt.rn.bf16x2.f32 %r14159, %r14093, %r14092; + cvt.rn.bf16x2.f32 %r14160, %r14095, %r14094; + cvt.rn.bf16x2.f32 %r14161, %r14097, %r14096; + cvt.rn.bf16x2.f32 %r14162, %r14099, %r14098; + cvt.rn.bf16x2.f32 %r14163, %r14101, %r14100; + cvt.rn.bf16x2.f32 %r14164, %r14103, %r14102; + cvt.rn.bf16x2.f32 %r14165, %r14105, %r14104; + cvt.rn.bf16x2.f32 %r14166, %r14107, %r14106; + cvt.rn.bf16x2.f32 %r14167, %r14109, %r14108; + cvt.rn.bf16x2.f32 %r14168, %r14111, %r14110; + cvt.rn.bf16x2.f32 %r14169, %r14113, %r14112; + cvt.rn.bf16x2.f32 %r14170, %r14115, %r14114; + cvt.rn.bf16x2.f32 %r14171, %r14117, %r14116; + cvt.rn.bf16x2.f32 %r14172, %r14119, %r14118; + cvt.rn.bf16x2.f32 %r14173, %r14121, %r14120; + cvt.rn.bf16x2.f32 %r14174, %r14123, %r14122; + cvt.rn.bf16x2.f32 %r14175, %r14125, %r14124; + cvt.rn.bf16x2.f32 %r14176, %r14127, %r14126; + cvt.rn.bf16x2.f32 %r14177, %r14129, %r14128; + cvt.rn.bf16x2.f32 %r14178, %r14131, %r14130; + cvt.rn.bf16x2.f32 %r14179, %r14133, %r14132; + cvt.rn.bf16x2.f32 %r14180, %r14135, %r14134; + cvt.rn.bf16x2.f32 %r14181, %r14137, %r14136; + cvt.rn.bf16x2.f32 %r14182, %r14139, %r14138; + cvt.rn.bf16x2.f32 %r14183, %r14141, %r14140; + cvt.rn.bf16x2.f32 %r14184, %r14143, %r14142; + bar.sync 0; + st.shared.v4.b32 [%r14068], {%r14153, %r14155, %r14157, %r14159}; + st.shared.v4.b32 [%r14068+512], {%r14154, %r14156, %r14158, %r14160}; + st.shared.v4.b32 [%r14070], {%r14161, %r14163, %r14165, %r14167}; + st.shared.v4.b32 [%r14070+512], {%r14162, %r14164, %r14166, %r14168}; + st.shared.v4.b32 [%r14072], {%r14169, %r14171, %r14173, %r14175}; + st.shared.v4.b32 [%r14072+512], {%r14170, %r14172, %r14174, %r14176}; + st.shared.v4.b32 [%r14074], {%r14177, %r14179, %r14181, %r14183}; + st.shared.v4.b32 [%r14074+512], {%r14178, %r14180, %r14182, %r14184}; + bar.sync 0; + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13984, %r13985, %r13986, %r13987}, [%r13876]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13988, %r13989, %r13990, %r13991}, [%r13881]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13992, %r13993, %r13994, %r13995}, [%r13886]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r13996, %r13997, %r13998, %r13999}, [%r13891]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r14000, %r14001, %r14002, %r14003}, [%r13896]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r14004, %r14005, %r14006, %r14007}, [%r13901]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r14008, %r14009, %r14010, %r14011}, [%r13906]; + // end inline asm + // begin inline asm + ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%r14012, %r14013, %r14014, %r14015}, [%r13911]; + // end inline asm + // begin inline asm + @%p1253 st.global.v4.b32 [ %rd1079 + 0 ], { %r13984, %r13985, %r13986, %r13987 }; + // end inline asm + // begin inline asm + @%p1254 st.global.v4.b32 [ %rd1080 + 0 ], { %r13988, %r13989, %r13990, %r13991 }; + // end inline asm + // begin inline asm + @%p1255 st.global.v4.b32 [ %rd1081 + 0 ], { %r13992, %r13993, %r13994, %r13995 }; + // end inline asm + // begin inline asm + @%p1256 st.global.v4.b32 [ %rd1082 + 0 ], { %r13996, %r13997, %r13998, %r13999 }; + // end inline asm + // begin inline asm + @%p1257 st.global.v4.b32 [ %rd1083 + 0 ], { %r14000, %r14001, %r14002, %r14003 }; + // end inline asm + // begin inline asm + @%p1258 st.global.v4.b32 [ %rd1084 + 0 ], { %r14004, %r14005, %r14006, %r14007 }; + // end inline asm + // begin inline asm + @%p1259 st.global.v4.b32 [ %rd1085 + 0 ], { %r14008, %r14009, %r14010, %r14011 }; + // end inline asm + // begin inline asm + @%p1260 st.global.v4.b32 [ %rd1086 + 0 ], { %r14012, %r14013, %r14014, %r14015 }; + // end inline asm +$L__BB0_17: + .loc 1 139 4 // cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py:139:4 + ret; +$L__tmp34: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py" + .file 2 "/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 5 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 5 // DW_FORM_data2 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 6 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 454 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x1bf DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 108 +.b8 55 +.b8 50 +.b8 113 +.b8 55 +.b8 107 +.b8 97 +.b8 50 +.b8 121 +.b8 99 +.b8 103 +.b8 51 +.b8 107 +.b8 104 +.b8 119 +.b8 104 +.b8 102 +.b8 113 +.b8 55 +.b8 118 +.b8 117 +.b8 103 +.b8 106 +.b8 104 +.b8 122 +.b8 100 +.b8 107 +.b8 97 +.b8 108 +.b8 118 +.b8 122 +.b8 54 +.b8 119 +.b8 104 +.b8 110 +.b8 113 +.b8 104 +.b8 119 +.b8 122 +.b8 112 +.b8 109 +.b8 118 +.b8 118 +.b8 117 +.b8 120 +.b8 115 +.b8 110 +.b8 99 +.b8 100 +.b8 99 +.b8 55 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 106 +.b8 117 +.b8 110 +.b8 113 +.b8 117 +.b8 97 +.b8 110 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 108 +.b8 55 +.b8 0 +.b8 2 // Abbrev [2] 0x8f:0x19 DW_TAG_subprogram +.b8 116 // DW_AT_name +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 95 +.b8 116 +.b8 101 +.b8 109 +.b8 95 +.b8 102 +.b8 117 +.b8 115 +.b8 101 +.b8 100 +.b8 95 +.b8 109 +.b8 117 +.b8 108 +.b8 95 +.b8 49 +.b8 0 +.b8 1 // DW_AT_inline +.b8 3 // Abbrev [3] 0xa8:0x121 DW_TAG_subprogram +.b64 $L__func_begin0 // DW_AT_low_pc +.b64 $L__func_end0 // DW_AT_high_pc +.b32 143 // DW_AT_abstract_origin +.b8 4 // Abbrev [4] 0xbd:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp1 // DW_AT_low_pc +.b64 $L__tmp2 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 112 // DW_AT_call_line +.b8 36 // DW_AT_call_column +.b8 5 // Abbrev [5] 0xd5:0x19 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp3 // DW_AT_low_pc +.b64 $L__tmp4 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 0 // DW_AT_call_line +.b8 1 +.b8 107 // DW_AT_call_column +.b8 5 // Abbrev [5] 0xee:0x19 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp4 // DW_AT_low_pc +.b64 $L__tmp5 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 1 // DW_AT_call_line +.b8 1 +.b8 107 // DW_AT_call_column +.b8 5 // Abbrev [5] 0x107:0x19 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp6 // DW_AT_low_pc +.b64 $L__tmp16 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 42 // DW_AT_call_line +.b8 1 +.b8 16 // DW_AT_call_column +.b8 5 // Abbrev [5] 0x120:0x19 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp10 // DW_AT_low_pc +.b64 $L__tmp17 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 62 // DW_AT_call_line +.b8 1 +.b8 20 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x139:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp18 // DW_AT_low_pc +.b64 $L__tmp19 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 113 // DW_AT_call_line +.b8 34 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x151:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp20 // DW_AT_low_pc +.b64 $L__tmp21 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 178 // DW_AT_call_line +.b8 107 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x169:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp21 // DW_AT_low_pc +.b64 $L__tmp22 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 179 // DW_AT_call_line +.b8 111 // DW_AT_call_column +.b8 4 // Abbrev [4] 0x181:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp23 // DW_AT_low_pc +.b64 $L__tmp29 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 207 // DW_AT_call_line +.b8 12 // DW_AT_call_column +.b8 6 // Abbrev [6] 0x199:0x17 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp27 // DW_AT_low_pc +.b64 $L__tmp28 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 0 // DW_AT_call_line +.b8 4 // Abbrev [4] 0x1b0:0x18 DW_TAG_inlined_subroutine +.b32 143 // DW_AT_abstract_origin +.b64 $L__tmp30 // DW_AT_low_pc +.b64 $L__tmp33 // DW_AT_high_pc +.b8 1 // DW_AT_call_file +.b8 226 // DW_AT_call_line +.b8 16 // DW_AT_call_column +.b8 0 // End Of Children Mark +.b8 0 // End Of Children Mark + } + .section .debug_macinfo { } diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.source b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.source new file mode 100644 index 0000000000000000000000000000000000000000..69eb23c5a67350e489b97a53900454628acb5423 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.source @@ -0,0 +1,2351 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":18:0) +#loc232 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":32:0) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":776:0) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":348:0) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":423:0) +#loc361 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":761:0) +#loc365 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":745:0) +#loc386 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":541:0) +#loc416 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":616:0) +#loc497 = loc("arg_Q"(#loc)) +#loc498 = loc("arg_K"(#loc)) +#loc499 = loc("arg_V"(#loc)) +#loc500 = loc("arg_LSE"(#loc)) +#loc501 = loc("arg_DELTA"(#loc)) +#loc502 = loc("arg_DO"(#loc)) +#loc503 = loc("arg_DQ"(#loc)) +#loc504 = loc("arg_DV"(#loc)) +#loc505 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc506 = loc("arg_KV_IDX"(#loc)) +#loc507 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc508 = loc("arg_Q_IDX"(#loc)) +#loc509 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc510 = loc("arg_FULL_KV_IDX"(#loc)) +#loc511 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc512 = loc("arg_FULL_Q_IDX"(#loc)) +#loc513 = loc("out_ptr0"(#loc)) +#loc514 = loc("ks0"(#loc)) +#loc515 = loc("ks1"(#loc)) +#loc702 = loc("x"(#loc232)) +#loc703 = loc("ptr"(#loc242)) +#loc704 = loc("offs_m"(#loc242)) +#loc705 = loc("offs_n"(#loc242)) +#loc706 = loc("stride_m"(#loc242)) +#loc707 = loc("stride_n"(#loc242)) +#loc708 = loc("M_LEN"(#loc242)) +#loc715 = loc("arg_Q"(#loc254)) +#loc716 = loc("arg_K"(#loc254)) +#loc717 = loc("arg_V"(#loc254)) +#loc718 = loc("arg_LSE"(#loc254)) +#loc719 = loc("arg_DELTA"(#loc254)) +#loc720 = loc("arg_DO"(#loc254)) +#loc721 = loc("arg_DQ"(#loc254)) +#loc722 = loc("arg_DV"(#loc254)) +#loc723 = loc("arg_KV_NUM_BLKS"(#loc254)) +#loc724 = loc("arg_KV_IDX"(#loc254)) +#loc725 = loc("arg_Q_NUM_BLKS"(#loc254)) +#loc726 = loc("arg_Q_IDX"(#loc254)) +#loc727 = loc("arg_FULL_KV_NUM_BLKS"(#loc254)) +#loc728 = loc("arg_FULL_KV_IDX"(#loc254)) +#loc729 = loc("arg_FULL_Q_NUM_BLKS"(#loc254)) +#loc730 = loc("arg_FULL_Q_IDX"(#loc254)) +#loc731 = loc("out_ptr0"(#loc254)) +#loc732 = loc("ks0"(#loc254)) +#loc733 = loc("ks1"(#loc254)) +#loc734 = loc("K"(#loc254)) +#loc735 = loc("V"(#loc254)) +#loc736 = loc("dq"(#loc254)) +#loc737 = loc("q"(#loc254)) +#loc738 = loc("do"(#loc254)) +#loc739 = loc("Di"(#loc254)) +#loc740 = loc("lse"(#loc254)) +#loc741 = loc("off_z"(#loc254)) +#loc742 = loc("off_hq"(#loc254)) +#loc743 = loc("offs_m2"(#loc254)) +#loc744 = loc("offs_n2"(#loc254)) +#loc745 = loc("stride_kn"(#loc254)) +#loc746 = loc("stride_kd"(#loc254)) +#loc747 = loc("stride_vn"(#loc254)) +#loc748 = loc("stride_vd"(#loc254)) +#loc749 = loc("kv_indices"(#loc254)) +#loc750 = loc("sparse_kv_num_blocks"(#loc254)) +#loc777 = loc("arg_Q"(#loc284)) +#loc778 = loc("arg_K"(#loc284)) +#loc779 = loc("arg_V"(#loc284)) +#loc780 = loc("arg_LSE"(#loc284)) +#loc781 = loc("arg_DELTA"(#loc284)) +#loc782 = loc("arg_DO"(#loc284)) +#loc783 = loc("arg_DQ"(#loc284)) +#loc784 = loc("arg_DV"(#loc284)) +#loc785 = loc("arg_KV_NUM_BLKS"(#loc284)) +#loc786 = loc("arg_KV_IDX"(#loc284)) +#loc787 = loc("arg_Q_NUM_BLKS"(#loc284)) +#loc788 = loc("arg_Q_IDX"(#loc284)) +#loc789 = loc("arg_FULL_KV_NUM_BLKS"(#loc284)) +#loc790 = loc("arg_FULL_KV_IDX"(#loc284)) +#loc791 = loc("arg_FULL_Q_NUM_BLKS"(#loc284)) +#loc792 = loc("arg_FULL_Q_IDX"(#loc284)) +#loc793 = loc("out_ptr0"(#loc284)) +#loc794 = loc("ks0"(#loc284)) +#loc795 = loc("ks1"(#loc284)) +#loc796 = loc("dq"(#loc284)) +#loc797 = loc("q"(#loc284)) +#loc798 = loc("kT_ptrs"(#loc284)) +#loc799 = loc("vT_ptrs"(#loc284)) +#loc800 = loc("do"(#loc284)) +#loc801 = loc("Di"(#loc284)) +#loc802 = loc("lse"(#loc284)) +#loc803 = loc("Q_LEN"(#loc284)) +#loc804 = loc("KV_LEN"(#loc284)) +#loc805 = loc("off_z"(#loc284)) +#loc806 = loc("off_hq"(#loc284)) +#loc807 = loc("offs_m2"(#loc284)) +#loc808 = loc("offs_n2"(#loc284)) +#loc809 = loc("offs_k"(#loc284)) +#loc810 = loc("offs_v"(#loc284)) +#loc811 = loc("stride_kn"(#loc284)) +#loc812 = loc("stride_kd"(#loc284)) +#loc813 = loc("stride_vn"(#loc284)) +#loc814 = loc("stride_vd"(#loc284)) +#loc815 = loc("kv_indices"(#loc284)) +#loc816 = loc("sparse_kv_num_blocks"(#loc284)) +#loc887 = loc("N_LEN"(#loc242)) +#loc888 = loc("indices"(#loc361)) +#loc889 = loc("max_len"(#loc361)) +#loc890 = loc("loop_iter"(#loc365)) +#loc891 = loc("col_indices"(#loc365)) +#loc892 = loc("total_blocks"(#loc365)) +#loc911 = loc("arg_Q"(#loc386)) +#loc912 = loc("arg_K"(#loc386)) +#loc913 = loc("arg_V"(#loc386)) +#loc914 = loc("arg_LSE"(#loc386)) +#loc915 = loc("arg_DELTA"(#loc386)) +#loc916 = loc("arg_DO"(#loc386)) +#loc917 = loc("arg_DQ"(#loc386)) +#loc918 = loc("arg_DV"(#loc386)) +#loc919 = loc("arg_KV_NUM_BLKS"(#loc386)) +#loc920 = loc("arg_KV_IDX"(#loc386)) +#loc921 = loc("arg_Q_NUM_BLKS"(#loc386)) +#loc922 = loc("arg_Q_IDX"(#loc386)) +#loc923 = loc("arg_FULL_KV_NUM_BLKS"(#loc386)) +#loc924 = loc("arg_FULL_KV_IDX"(#loc386)) +#loc925 = loc("arg_FULL_Q_NUM_BLKS"(#loc386)) +#loc926 = loc("arg_FULL_Q_IDX"(#loc386)) +#loc927 = loc("out_ptr0"(#loc386)) +#loc928 = loc("ks0"(#loc386)) +#loc929 = loc("ks1"(#loc386)) +#loc930 = loc("Q"(#loc386)) +#loc931 = loc("DO"(#loc386)) +#loc932 = loc("DELTA"(#loc386)) +#loc933 = loc("LSE"(#loc386)) +#loc934 = loc("dk"(#loc386)) +#loc935 = loc("dv"(#loc386)) +#loc936 = loc("k"(#loc386)) +#loc937 = loc("v"(#loc386)) +#loc938 = loc("off_z"(#loc386)) +#loc939 = loc("off_hq"(#loc386)) +#loc940 = loc("offs_n1"(#loc386)) +#loc941 = loc("offs_m1"(#loc386)) +#loc942 = loc("stride_qm"(#loc386)) +#loc943 = loc("stride_qd"(#loc386)) +#loc944 = loc("stride_dom"(#loc386)) +#loc945 = loc("stride_dod"(#loc386)) +#loc946 = loc("q_indices"(#loc386)) +#loc947 = loc("sparse_q_num_blocks"(#loc386)) +#loc973 = loc("arg_Q"(#loc416)) +#loc974 = loc("arg_K"(#loc416)) +#loc975 = loc("arg_V"(#loc416)) +#loc976 = loc("arg_LSE"(#loc416)) +#loc977 = loc("arg_DELTA"(#loc416)) +#loc978 = loc("arg_DO"(#loc416)) +#loc979 = loc("arg_DQ"(#loc416)) +#loc980 = loc("arg_DV"(#loc416)) +#loc981 = loc("arg_KV_NUM_BLKS"(#loc416)) +#loc982 = loc("arg_KV_IDX"(#loc416)) +#loc983 = loc("arg_Q_NUM_BLKS"(#loc416)) +#loc984 = loc("arg_Q_IDX"(#loc416)) +#loc985 = loc("arg_FULL_KV_NUM_BLKS"(#loc416)) +#loc986 = loc("arg_FULL_KV_IDX"(#loc416)) +#loc987 = loc("arg_FULL_Q_NUM_BLKS"(#loc416)) +#loc988 = loc("arg_FULL_Q_IDX"(#loc416)) +#loc989 = loc("out_ptr0"(#loc416)) +#loc990 = loc("ks0"(#loc416)) +#loc991 = loc("ks1"(#loc416)) +#loc992 = loc("dk"(#loc416)) +#loc993 = loc("dv"(#loc416)) +#loc994 = loc("qT_ptrs"(#loc416)) +#loc995 = loc("k"(#loc416)) +#loc996 = loc("v"(#loc416)) +#loc997 = loc("do_ptrs"(#loc416)) +#loc998 = loc("DELTA"(#loc416)) +#loc999 = loc("LSE"(#loc416)) +#loc1000 = loc("Q_LEN"(#loc416)) +#loc1001 = loc("KV_LEN"(#loc416)) +#loc1002 = loc("off_z"(#loc416)) +#loc1003 = loc("off_hq"(#loc416)) +#loc1004 = loc("offs_n1"(#loc416)) +#loc1005 = loc("offs_m1"(#loc416)) +#loc1006 = loc("offs_k"(#loc416)) +#loc1007 = loc("offs_v"(#loc416)) +#loc1008 = loc("stride_qm"(#loc416)) +#loc1009 = loc("stride_qd"(#loc416)) +#loc1010 = loc("stride_dom"(#loc416)) +#loc1011 = loc("stride_dod"(#loc416)) +#loc1012 = loc("q_indices"(#loc416)) +#loc1013 = loc("sparse_q_num_blocks"(#loc416)) +module { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc))) attributes {noinline = false} { + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %c4096_i32_0 = arith.constant 4096 : i32 loc(#loc1) + %0 = arith.muli %c4096_i32_0, %ks0 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc2) + %c4096_i32_1 = arith.constant 4096 : i32 loc(#loc2) + %c1_i32 = arith.constant 1 : i32 loc(#loc2) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc3) + %c1024_i32_2 = arith.constant 1024 : i32 loc(#loc3) + %1 = arith.muli %c1024_i32_2, %ks1 : i32 loc(#loc3) + %c128_i32_3 = arith.constant 128 : i32 loc(#loc4) + %c1024_i32_4 = arith.constant 1024 : i32 loc(#loc4) + %c1_i32_5 = arith.constant 1 : i32 loc(#loc4) + %c1024_i32_6 = arith.constant 1024 : i32 loc(#loc5) + %c1024_i32_7 = arith.constant 1024 : i32 loc(#loc5) + %2 = arith.muli %c1024_i32_7, %ks1 : i32 loc(#loc5) + %c128_i32_8 = arith.constant 128 : i32 loc(#loc6) + %c1024_i32_9 = arith.constant 1024 : i32 loc(#loc6) + %c1_i32_10 = arith.constant 1 : i32 loc(#loc6) + %c1_i32_11 = arith.constant 1 : i32 loc(#loc7) + %3 = arith.cmpi sge, %c1_i32_11, %ks0 : i32 loc(#loc7) + %c1_i32_12 = arith.constant 1 : i32 loc(#loc8) + %c1_i32_13 = arith.constant 1 : i32 loc(#loc8) + %4 = arith.extui %3 : i1 to i32 loc(#loc8) + %5 = arith.muli %c1_i32_13, %4 : i32 loc(#loc8) + %c1_i32_14 = arith.constant 1 : i32 loc(#loc9) + %6 = arith.cmpi sgt, %ks0, %c1_i32_14 : i32 loc(#loc9) + %7 = arith.extui %6 : i1 to i32 loc(#loc10) + %8 = arith.muli %ks0, %7 : i32 loc(#loc10) + %9 = arith.addi %5, %8 : i32 loc(#loc11) + %c4096_i32_15 = arith.constant 4096 : i32 loc(#loc12) + %c4096_i32_16 = arith.constant 4096 : i32 loc(#loc12) + %10 = arith.muli %c4096_i32_16, %9 : i32 loc(#loc12) + %c1_i32_17 = arith.constant 1 : i32 loc(#loc13) + %11 = arith.cmpi sge, %c1_i32_17, %ks0 : i32 loc(#loc13) + %c1_i32_18 = arith.constant 1 : i32 loc(#loc14) + %c1_i32_19 = arith.constant 1 : i32 loc(#loc14) + %12 = arith.extui %11 : i1 to i32 loc(#loc14) + %13 = arith.muli %c1_i32_19, %12 : i32 loc(#loc14) + %c1_i32_20 = arith.constant 1 : i32 loc(#loc15) + %14 = arith.cmpi sgt, %ks0, %c1_i32_20 : i32 loc(#loc15) + %15 = arith.extui %14 : i1 to i32 loc(#loc16) + %16 = arith.muli %ks0, %15 : i32 loc(#loc16) + %17 = arith.addi %13, %16 : i32 loc(#loc17) + %c128_i32_21 = arith.constant 128 : i32 loc(#loc18) + %c128_i32_22 = arith.constant 128 : i32 loc(#loc18) + %18 = arith.muli %c128_i32_22, %17 : i32 loc(#loc18) + %c128_i32_23 = arith.constant 128 : i32 loc(#loc19) + %c1_i32_24 = arith.constant 1 : i32 loc(#loc19) + %c4096_i32_25 = arith.constant 4096 : i32 loc(#loc20) + %c4096_i32_26 = arith.constant 4096 : i32 loc(#loc20) + %19 = arith.muli %c4096_i32_26, %ks0 : i32 loc(#loc20) + %c128_i32_27 = arith.constant 128 : i32 loc(#loc21) + %c4096_i32_28 = arith.constant 4096 : i32 loc(#loc21) + %c1_i32_29 = arith.constant 1 : i32 loc(#loc21) + %c1024_i32_30 = arith.constant 1024 : i32 loc(#loc22) + %c1024_i32_31 = arith.constant 1024 : i32 loc(#loc22) + %20 = arith.muli %c1024_i32_31, %ks1 : i32 loc(#loc22) + %c128_i32_32 = arith.constant 128 : i32 loc(#loc23) + %c1024_i32_33 = arith.constant 1024 : i32 loc(#loc23) + %c1_i32_34 = arith.constant 1 : i32 loc(#loc23) + %ZQ = arith.constant 1 : i32 loc(#loc516) + %HQ = arith.constant 32 : i32 loc(#loc517) + %HKV = arith.constant 8 : i32 loc(#loc518) + %ZKV = arith.constant 1 : i32 loc(#loc519) + %pid = tt.get_program_id x : i32 loc(#loc520) + %NUM_KV_BLOCKS = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%ks1) : (i32) -> i32 loc(#loc521) + %NUM_Q_BLOCKS = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%ks0) : (i32) -> i32 loc(#loc522) + %off_zq = tt.get_program_id y : i32 loc(#loc523) + %off_hkv = tt.get_program_id z : i32 loc(#loc524) + %off_zkv = arith.remsi %off_zq, %ZKV : i32 loc(#loc525) + %SPARSE_Z = arith.constant 1 : i32 loc(#loc526) + %SPARSE_HQ = arith.constant 1 : i32 loc(#loc527) + %sparse_idx_z = arith.remsi %off_zq, %SPARSE_Z : i32 loc(#loc528) + %k_adj = arith.muli %c128_i32_3, %off_hkv : i32 loc(#loc529) + %k_adj_35 = arith.muli %1, %off_zkv : i32 loc(#loc530) + %k_adj_36 = arith.addi %k_adj, %k_adj_35 : i32 loc(#loc531) + %k_adj_37 = arith.extsi %k_adj_36 : i32 to i64 loc(#loc532) + %v_adj = arith.muli %c128_i32_8, %off_hkv : i32 loc(#loc533) + %v_adj_38 = arith.muli %2, %off_zkv : i32 loc(#loc534) + %v_adj_39 = arith.addi %v_adj, %v_adj_38 : i32 loc(#loc535) + %v_adj_40 = arith.extsi %v_adj_39 : i32 to i64 loc(#loc536) + %dv_adj = arith.muli %c128_i32_32, %off_hkv : i32 loc(#loc537) + %dv_adj_41 = arith.muli %20, %off_zq : i32 loc(#loc538) + %dv_adj_42 = arith.addi %dv_adj, %dv_adj_41 : i32 loc(#loc539) + %dv_adj_43 = arith.extsi %dv_adj_42 : i32 to i64 loc(#loc540) + %K = tt.addptr %arg_K, %k_adj_37 : !tt.ptr, i64 loc(#loc541) + %V = tt.addptr %arg_V, %v_adj_40 : !tt.ptr, i64 loc(#loc542) + %DV = tt.addptr %arg_DV, %dv_adj_43 : !tt.ptr, i64 loc(#loc543) + %RCP_LN2 = arith.constant 1.44269502 : f32 loc(#loc544) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc545) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc546) + %21 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS : i32 loc(#loc55) + %22:2 = scf.if %21 -> (i32, i32) { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS : i32 loc(#loc547) + %SPARSE_Q_MULTIPLE = arith.constant 1 : i32 loc(#loc1092) + %SPARSE_KV_MULTIPLE = arith.constant 2 : i32 loc(#loc1093) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc550) + %off_hq2_44 = arith.constant 4 : i32 loc(#loc551) + %off_hq2_45 = arith.constant 4 : i32 loc(#loc551) + %off_hq2_46 = arith.muli %off_hkv, %off_hq2_45 : i32 loc(#loc551) + %off_hq2_47 = arith.addi %off_hq2, %off_hq2_46 : i32 loc(#loc552) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS : i32 loc(#loc553) + %off_pid_mask = arith.divsi %start_m2_block, %SPARSE_Q_MULTIPLE : i32 loc(#loc554) + %stride_kv_num_blks_h = arith.constant 1 : i32 loc(#loc555) + %stride_kv_idx_h = arith.constant 1 : i32 loc(#loc556) + %stride_kv_idx_m = arith.constant 1 : i32 loc(#loc557) + %sparse_idx_hq2 = arith.remsi %off_hq2_47, %SPARSE_HQ : i32 loc(#loc558) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc559) + %sparse_hz_offset_48 = arith.addi %sparse_hz_offset, %sparse_idx_hq2 : i32 loc(#loc560) + %sparse_kv_num_blks_offset = arith.muli %sparse_hz_offset_48, %stride_kv_num_blks_h : i32 loc(#loc561) + %sparse_kv_num_blks_offset_49 = arith.addi %sparse_kv_num_blks_offset, %off_pid_mask : i32 loc(#loc562) + %sparse_kv_idx_offset = arith.muli %sparse_hz_offset_48, %stride_kv_idx_h : i32 loc(#loc563) + %sparse_kv_idx_offset_50 = arith.muli %off_pid_mask, %stride_kv_idx_m : i32 loc(#loc564) + %sparse_kv_idx_offset_51 = arith.addi %sparse_kv_idx_offset, %sparse_kv_idx_offset_50 : i32 loc(#loc565) + %q_adj2 = arith.muli %c128_i32, %off_hq2_47 : i32 loc(#loc566) + %q_adj2_52 = arith.muli %0, %off_zq : i32 loc(#loc567) + %q_adj2_53 = arith.addi %q_adj2, %q_adj2_52 : i32 loc(#loc568) + %q_adj2_54 = arith.extsi %q_adj2_53 : i32 to i64 loc(#loc569) + %do_adj2 = arith.muli %18, %off_hq2_47 : i32 loc(#loc570) + %do_adj2_55 = arith.muli %10, %off_zq : i32 loc(#loc571) + %do_adj2_56 = arith.addi %do_adj2, %do_adj2_55 : i32 loc(#loc572) + %do_adj2_57 = arith.extsi %do_adj2_56 : i32 to i64 loc(#loc573) + %dq_adj2 = arith.muli %c128_i32_27, %off_hq2_47 : i32 loc(#loc574) + %dq_adj2_58 = arith.muli %19, %off_zq : i32 loc(#loc575) + %dq_adj2_59 = arith.addi %dq_adj2, %dq_adj2_58 : i32 loc(#loc576) + %dq_adj2_60 = arith.extsi %dq_adj2_59 : i32 to i64 loc(#loc577) + %off_chz2 = arith.muli %off_zq, %HQ : i32 loc(#loc578) + %off_chz2_61 = arith.addi %off_chz2, %off_hq2_47 : i32 loc(#loc579) + %off_chz2_62 = arith.muli %off_chz2_61, %ks0 : i32 loc(#loc580) + %off_chz2_63 = arith.extsi %off_chz2_62 : i32 to i64 loc(#loc581) + %Q2 = tt.addptr %arg_Q, %q_adj2_54 : !tt.ptr, i64 loc(#loc582) + %DO2 = tt.addptr %arg_DO, %do_adj2_57 : !tt.ptr, i64 loc(#loc583) + %DQ2 = tt.addptr %arg_DQ, %dq_adj2_60 : !tt.ptr, i64 loc(#loc584) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_63 : !tt.ptr, i64 loc(#loc585) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_63 : !tt.ptr, i64 loc(#loc586) + %dq = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc587) + %start_m2 = arith.constant 128 : i32 loc(#loc588) + %start_m2_64 = arith.constant 128 : i32 loc(#loc588) + %start_m2_65 = arith.muli %start_m2_block, %start_m2_64 : i32 loc(#loc588) + %offs_m2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc589) + %offs_m2_66 = tt.splat %start_m2_65 : i32 -> tensor<128xi32> loc(#loc590) + %offs_m2_67 = arith.addi %offs_m2_66, %offs_m2 : tensor<128xi32> loc(#loc590) + %q = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%Q2, %offs_m2_67, %offs_k, %c4096_i32_1, %c1_i32, %ks0) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc591) + %do = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%DO2, %offs_m2_67, %offs_v, %c128_i32_23, %c1_i32_24, %ks0) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc592) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc593) + %Di_68 = arith.cmpi slt, %offs_m2_67, %Di : tensor<128xi32> loc(#loc593) + %Di_69 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc594) + %Di_70 = tt.addptr %Di_69, %offs_m2_67 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc594) + %Di_71 = tt.load %Di_70, %Di_68 : tensor<128x!tt.ptr> loc(#loc595) + %lse = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc596) + %lse_72 = arith.cmpi slt, %offs_m2_67, %lse : tensor<128xi32> loc(#loc596) + %lse_73 = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc597) + %lse_74 = tt.addptr %lse_73, %offs_m2_67 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc597) + %lse_75 = tt.load %lse_74, %lse_72 : tensor<128x!tt.ptr> loc(#loc598) + %lse_76 = arith.constant 0xFF800000 : f32 loc(#loc599) + %lse_77 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc599) + %lse_78 = arith.cmpf oeq, %lse_75, %lse_77 : tensor<128xf32> loc(#loc599) + %lse_79 = arith.constant 0.000000e+00 : f32 loc(#loc600) + %lse_80 = arith.constant 0.000000e+00 : f32 loc(#loc600) + %lse_81 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc600) + %lse_82 = arith.select %lse_78, %lse_81, %lse_75 : tensor<128xi1>, tensor<128xf32> loc(#loc600) + %lse_83 = tt.expand_dims %lse_82 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc601) + %kv_indices = tt.addptr %arg_KV_IDX, %sparse_kv_idx_offset_51 : !tt.ptr, i32 loc(#loc602) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc603) + %kv_start_84 = arith.constant 128 : i32 loc(#loc604) + %kv_start_85 = arith.constant 128 : i32 loc(#loc604) + %kv_start_86 = arith.muli %kv_start, %kv_start_85 : i32 loc(#loc604) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %sparse_kv_num_blks_offset_49 : !tt.ptr, i32 loc(#loc605) + %sparse_kv_num_blocks_87 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc606) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc607) + %offs_n2_88 = tt.splat %kv_start_86 : i32 -> tensor<64xi32> loc(#loc608) + %offs_n2_89 = arith.addi %offs_n2_88, %offs_n2 : tensor<64xi32> loc(#loc608) + %dq_90 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %K, %V, %dq, %q, %do, %Di_71, %lse_83, %off_zq, %off_hq2_47, %offs_m2_67, %offs_n2_89, %c1024_i32_4, %c1_i32_5, %c1024_i32_9, %c1_i32_10, %kv_indices, %sparse_kv_num_blocks_87) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc609) + %kv_indices_91 = tt.addptr %arg_FULL_KV_IDX, %sparse_kv_idx_offset_51 : !tt.ptr, i32 loc(#loc610) + %kv_start_92 = tt.load %kv_indices_91 : !tt.ptr loc(#loc611) + %kv_start_93 = arith.constant 128 : i32 loc(#loc612) + %kv_start_94 = arith.constant 128 : i32 loc(#loc612) + %kv_start_95 = arith.muli %kv_start_92, %kv_start_94 : i32 loc(#loc612) + %sparse_kv_num_blocks_96 = tt.addptr %arg_FULL_KV_NUM_BLKS, %sparse_kv_num_blks_offset_49 : !tt.ptr, i32 loc(#loc613) + %sparse_kv_num_blocks_97 = tt.load %sparse_kv_num_blocks_96 : !tt.ptr loc(#loc614) + %offs_n2_98 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc615) + %offs_n2_99 = tt.splat %kv_start_95 : i32 -> tensor<64xi32> loc(#loc616) + %offs_n2_100 = arith.addi %offs_n2_99, %offs_n2_98 : tensor<64xi32> loc(#loc616) + %dq_101 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %K, %V, %dq_90, %q, %do, %Di_71, %lse_83, %off_zq, %off_hq2_47, %offs_m2_67, %offs_n2_100, %c1024_i32_4, %c1_i32_5, %c1024_i32_9, %c1_i32_10, %kv_indices_91, %sparse_kv_num_blocks_97) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc617) + %dq_ptrs = tt.expand_dims %offs_m2_67 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc618) + %dq_ptrs_102 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc619) + %dq_ptrs_103 = arith.muli %dq_ptrs, %dq_ptrs_102 : tensor<128x1xi32> loc(#loc619) + %dq_ptrs_104 = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc620) + %dq_ptrs_105 = tt.addptr %dq_ptrs_104, %dq_ptrs_103 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc620) + %dq_ptrs_106 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc621) + %dq_ptrs_107 = arith.constant dense<1> : tensor<1x128xi32> loc(#loc622) + %dq_ptrs_108 = arith.muli %dq_ptrs_106, %dq_ptrs_107 : tensor<1x128xi32> loc(#loc622) + %dq_ptrs_109 = tt.broadcast %dq_ptrs_105 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc623) + %dq_ptrs_110 = tt.broadcast %dq_ptrs_108 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc623) + %dq_ptrs_111 = tt.addptr %dq_ptrs_109, %dq_ptrs_110 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc623) + %dq_112 = arith.constant 0.0883883461 : f32 loc(#loc624) + %dq_113 = arith.constant 0.0883883461 : f32 loc(#loc624) + %dq_114 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc624) + %dq_115 = arith.mulf %dq_101, %dq_114 : tensor<128x128xf32> loc(#loc624) + %23 = tt.expand_dims %offs_m2_67 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc135) + %24 = tt.splat %ks0 : i32 -> tensor<128x1xi32> loc(#loc136) + %25 = arith.cmpi slt, %23, %24 : tensor<128x1xi32> loc(#loc136) + %26 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc137) + %c128_i32_116 = arith.constant 128 : i32 loc(#loc138) + %cst = arith.constant dense<128> : tensor<1x128xi32> loc(#loc138) + %27 = arith.cmpi slt, %26, %cst : tensor<1x128xi32> loc(#loc138) + %28 = tt.broadcast %25 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc139) + %29 = tt.broadcast %27 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc139) + %30 = arith.andi %28, %29 : tensor<128x128xi1> loc(#loc139) + %31 = arith.truncf %dq_115 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc140) + tt.store %dq_ptrs_111, %31, %30 : tensor<128x128x!tt.ptr> loc(#loc140) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc140) + } else { + %SPARSE_Q_MULTIPLE = arith.constant 2 : i32 loc(#loc1094) + %SPARSE_KV_MULTIPLE = arith.constant 1 : i32 loc(#loc1095) + %pid_mask = arith.divsi %pid, %SPARSE_KV_MULTIPLE : i32 loc(#loc627) + %stride_q_num_blks_h = arith.constant 1 : i32 loc(#loc628) + %stride_q_idx_h = arith.constant 1 : i32 loc(#loc629) + %stride_q_idx_n = arith.constant 1 : i32 loc(#loc630) + %dv = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc631) + %dk = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() : () -> tensor<128x128xf32> loc(#loc632) + %start_n1 = arith.constant 128 : i32 loc(#loc633) + %start_n1_44 = arith.constant 128 : i32 loc(#loc633) + %start_n1_45 = arith.muli %pid, %start_n1_44 : i32 loc(#loc633) + %offs_n1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc634) + %offs_n1_46 = tt.splat %start_n1_45 : i32 -> tensor<128xi32> loc(#loc635) + %offs_n1_47 = arith.addi %offs_n1_46, %offs_n1 : tensor<128xi32> loc(#loc635) + %k = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%K, %offs_n1_47, %offs_k, %c1024_i32_4, %c1_i32_5, %ks1) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc636) + %v = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%V, %offs_n1_47, %offs_v, %c1024_i32_9, %c1_i32_10, %ks1) : (!tt.ptr, tensor<128xi32>, tensor<128xi32>, i32, i32, i32) -> tensor<128x128xbf16> loc(#loc637) + %c0_i32 = arith.constant 0 : i32 loc(#loc154) + %c4_i32 = arith.constant 4 : i32 loc(#loc154) + %c1_i32_48 = arith.constant 1 : i32 loc(#loc154) + %23 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc154) + %24 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc154) + %25 = arith.bitcast %c1_i32_48 : i32 to i32 loc(#loc154) + %26 = ub.poison : i32 loc(#loc154) + %dk_49:2 = scf.for %off_g = %23 to %24 step %25 iter_args(%dv_89 = %dv, %dk_90 = %dk) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.constant 4 : i32 loc(#loc639) + %off_hq1_91 = arith.constant 4 : i32 loc(#loc639) + %off_hq1_92 = arith.muli %off_hkv, %off_hq1_91 : i32 loc(#loc639) + %off_hq1_93 = arith.addi %off_hq1_92, %off_g : i32 loc(#loc640) + %q_adj1 = arith.muli %c128_i32, %off_hq1_93 : i32 loc(#loc641) + %q_adj1_94 = arith.muli %0, %off_zq : i32 loc(#loc642) + %q_adj1_95 = arith.addi %q_adj1, %q_adj1_94 : i32 loc(#loc643) + %q_adj1_96 = arith.extsi %q_adj1_95 : i32 to i64 loc(#loc644) + %do_adj1 = arith.muli %18, %off_hq1_93 : i32 loc(#loc645) + %do_adj1_97 = arith.muli %10, %off_zq : i32 loc(#loc646) + %do_adj1_98 = arith.addi %do_adj1, %do_adj1_97 : i32 loc(#loc647) + %do_adj1_99 = arith.extsi %do_adj1_98 : i32 to i64 loc(#loc648) + %dq_adj1 = arith.muli %c128_i32_27, %off_hq1_93 : i32 loc(#loc649) + %dq_adj1_100 = arith.muli %19, %off_zq : i32 loc(#loc650) + %dq_adj1_101 = arith.addi %dq_adj1, %dq_adj1_100 : i32 loc(#loc651) + %dq_adj1_102 = arith.extsi %dq_adj1_101 : i32 to i64 loc(#loc652) + %off_chz1 = arith.muli %off_zq, %HQ : i32 loc(#loc653) + %off_chz1_103 = arith.addi %off_chz1, %off_hq1_93 : i32 loc(#loc654) + %off_chz1_104 = arith.muli %off_chz1_103, %ks0 : i32 loc(#loc655) + %off_chz1_105 = arith.extsi %off_chz1_104 : i32 to i64 loc(#loc656) + %Q1 = tt.addptr %arg_Q, %q_adj1_96 : !tt.ptr, i64 loc(#loc657) + %DO1 = tt.addptr %arg_DO, %do_adj1_99 : !tt.ptr, i64 loc(#loc658) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_105 : !tt.ptr, i64 loc(#loc659) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_105 : !tt.ptr, i64 loc(#loc660) + %sparse_idx_hq1 = arith.remsi %off_hq1_93, %SPARSE_HQ : i32 loc(#loc661) + %sparse_hz_offset = arith.muli %sparse_idx_z, %SPARSE_HQ : i32 loc(#loc662) + %sparse_hz_offset_106 = arith.addi %sparse_hz_offset, %sparse_idx_hq1 : i32 loc(#loc663) + %sparse_q_num_blks_offset = arith.muli %sparse_hz_offset_106, %stride_q_num_blks_h : i32 loc(#loc664) + %sparse_q_num_blks_offset_107 = arith.addi %sparse_q_num_blks_offset, %pid_mask : i32 loc(#loc665) + %sparse_q_idx_offset = arith.muli %sparse_hz_offset_106, %stride_q_idx_h : i32 loc(#loc666) + %sparse_q_idx_offset_108 = arith.muli %pid_mask, %stride_q_idx_n : i32 loc(#loc667) + %sparse_q_idx_offset_109 = arith.addi %sparse_q_idx_offset, %sparse_q_idx_offset_108 : i32 loc(#loc668) + %q_indices = tt.addptr %arg_Q_IDX, %sparse_q_idx_offset_109 : !tt.ptr, i32 loc(#loc669) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc670) + %q_start_110 = arith.constant 128 : i32 loc(#loc671) + %q_start_111 = arith.constant 128 : i32 loc(#loc671) + %q_start_112 = arith.muli %q_start, %q_start_111 : i32 loc(#loc671) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %sparse_q_num_blks_offset_107 : !tt.ptr, i32 loc(#loc672) + %sparse_q_num_blocks_113 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc673) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc674) + %offs_m1_114 = tt.splat %q_start_112 : i32 -> tensor<64xi32> loc(#loc675) + %offs_m1_115 = arith.addi %offs_m1_114, %offs_m1 : tensor<64xi32> loc(#loc675) + %45:2 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(37,)cconstexpr_bf16__(38,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %Q1, %DO1, %DELTA1, %LSE1, %dk_90, %dv_89, %k, %v, %off_zq, %off_hq1_93, %offs_n1_47, %offs_m1_115, %c4096_i32_1, %c1_i32, %c128_i32_23, %c1_i32_24, %q_indices, %sparse_q_num_blocks_113) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc192) + %q_indices_116 = tt.addptr %arg_FULL_Q_IDX, %sparse_q_idx_offset_109 : !tt.ptr, i32 loc(#loc676) + %q_start_117 = tt.load %q_indices_116 : !tt.ptr loc(#loc677) + %q_start_118 = arith.constant 128 : i32 loc(#loc678) + %q_start_119 = arith.constant 128 : i32 loc(#loc678) + %q_start_120 = arith.muli %q_start_117, %q_start_119 : i32 loc(#loc678) + %sparse_q_num_blocks_121 = tt.addptr %arg_FULL_Q_NUM_BLKS, %sparse_q_num_blks_offset_107 : !tt.ptr, i32 loc(#loc679) + %sparse_q_num_blocks_122 = tt.load %sparse_q_num_blocks_121 : !tt.ptr loc(#loc680) + %offs_m1_123 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc681) + %offs_m1_124 = tt.splat %q_start_120 : i32 -> tensor<64xi32> loc(#loc682) + %offs_m1_125 = arith.addi %offs_m1_124, %offs_m1_123 : tensor<64xi32> loc(#loc682) + %46:2 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(37,)cconstexpr_bf16__(38,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %Q1, %DO1, %DELTA1, %LSE1, %45#0, %45#1, %k, %v, %off_zq, %off_hq1_93, %offs_n1_47, %offs_m1_125, %c4096_i32_1, %c1_i32, %c128_i32_23, %c1_i32_24, %q_indices_116, %sparse_q_num_blocks_122) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x128xbf16>, i32, i32, tensor<128xi32>, tensor<64xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc200) + scf.yield %46#1, %46#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc201) + } loc(#loc1096) + %dv_ptrs = tt.expand_dims %offs_n1_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc683) + %dv_ptrs_50 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc684) + %dv_ptrs_51 = arith.muli %dv_ptrs, %dv_ptrs_50 : tensor<128x1xi32> loc(#loc684) + %dv_ptrs_52 = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc685) + %dv_ptrs_53 = tt.addptr %dv_ptrs_52, %dv_ptrs_51 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc685) + %dv_ptrs_54 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc686) + %dv_ptrs_55 = arith.constant dense<1> : tensor<1x128xi32> loc(#loc687) + %dv_ptrs_56 = arith.muli %dv_ptrs_54, %dv_ptrs_55 : tensor<1x128xi32> loc(#loc687) + %dv_ptrs_57 = tt.broadcast %dv_ptrs_53 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc688) + %dv_ptrs_58 = tt.broadcast %dv_ptrs_56 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc688) + %dv_ptrs_59 = tt.addptr %dv_ptrs_57, %dv_ptrs_58 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc688) + %index_n = tt.expand_dims %offs_n1_47 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc689) + %index_k = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc690) + %index_v = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc691) + %27 = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc211) + %28 = arith.cmpi slt, %index_n, %27 : tensor<128x1xi32> loc(#loc211) + %c128_i32_60 = arith.constant 128 : i32 loc(#loc212) + %cst = arith.constant dense<128> : tensor<1x128xi32> loc(#loc212) + %29 = arith.cmpi slt, %index_v, %cst : tensor<1x128xi32> loc(#loc212) + %30 = tt.broadcast %28 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc213) + %31 = tt.broadcast %29 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc213) + %32 = arith.andi %30, %31 : tensor<128x128xi1> loc(#loc213) + %33 = arith.truncf %dk_49#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc214) + tt.store %dv_ptrs_59, %33, %32 : tensor<128x128x!tt.ptr> loc(#loc214) + %dk_61 = arith.constant 0.0883883461 : f32 loc(#loc692) + %dk_62 = arith.constant 0.0883883461 : f32 loc(#loc692) + %dk_63 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc692) + %dk_64 = arith.mulf %dk_49#1, %dk_63 : tensor<128x128xf32> loc(#loc692) + %mask = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc693) + %mask_65 = arith.cmpi slt, %index_n, %mask : tensor<128x1xi32> loc(#loc693) + %xindex = arith.constant 128 : i32 loc(#loc694) + %xindex_66 = arith.constant 128 : i32 loc(#loc694) + %xindex_67 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc694) + %xindex_68 = arith.muli %xindex_67, %index_n : tensor<128x1xi32> loc(#loc694) + %xindex_69 = tt.broadcast %index_k : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc695) + %xindex_70 = tt.broadcast %xindex_68 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc695) + %xindex_71 = arith.addi %xindex_69, %xindex_70 : tensor<128x128xi32> loc(#loc695) + %xindex_72 = arith.constant 128 : i32 loc(#loc696) + %xindex_73 = arith.constant 128 : i32 loc(#loc696) + %xindex_74 = arith.muli %xindex_73, %off_hkv : i32 loc(#loc696) + %xindex_75 = arith.muli %xindex_74, %ks1 : i32 loc(#loc697) + %xindex_76 = tt.splat %xindex_75 : i32 -> tensor<128x128xi32> loc(#loc698) + %xindex_77 = arith.addi %xindex_71, %xindex_76 : tensor<128x128xi32> loc(#loc698) + %xindex_78 = arith.constant 1024 : i32 loc(#loc699) + %xindex_79 = arith.constant 1024 : i32 loc(#loc699) + %xindex_80 = arith.muli %xindex_79, %off_zq : i32 loc(#loc699) + %xindex_81 = arith.muli %xindex_80, %ks1 : i32 loc(#loc700) + %xindex_82 = tt.splat %xindex_81 : i32 -> tensor<128x128xi32> loc(#loc701) + %xindex_83 = arith.addi %xindex_77, %xindex_82 : tensor<128x128xi32> loc(#loc701) + %c128_i32_84 = arith.constant 128 : i32 loc(#loc225) + %c128_i32_85 = arith.constant 128 : i32 loc(#loc225) + %34 = arith.muli %c128_i32_85, %off_hkv : i32 loc(#loc225) + %35 = tt.splat %34 : i32 -> tensor<1x128xi32> loc(#loc226) + %36 = arith.addi %index_k, %35 : tensor<1x128xi32> loc(#loc226) + %c1024_i32_86 = arith.constant 1024 : i32 loc(#loc227) + %c1024_i32_87 = arith.constant 1024 : i32 loc(#loc227) + %cst_88 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc227) + %37 = arith.muli %cst_88, %index_n : tensor<128x1xi32> loc(#loc227) + %38 = tt.broadcast %36 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc228) + %39 = tt.broadcast %37 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc228) + %40 = arith.addi %38, %39 : tensor<128x128xi32> loc(#loc228) + %41 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc229) + %42 = tt.addptr %41, %40 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc229) + %43 = tt.broadcast %mask_65 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc230) + %44 = arith.truncf %dk_64 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc230) + tt.store %42, %44, %43 : tensor<128x128x!tt.ptr> loc(#loc230) + scf.yield %SPARSE_KV_MULTIPLE, %SPARSE_Q_MULTIPLE : i32, i32 loc(#loc230) + } loc(#loc56) + tt.return loc(#loc231) + } loc(#loc) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_128_"(%x: i32 loc("x"(#loc232))) -> i32 attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 loc(#loc233) + %c128_i32_0 = arith.constant 128 : i32 loc(#loc233) + %0 = arith.addi %x, %c128_i32_0 : i32 loc(#loc233) + %c1_i32 = arith.constant 1 : i32 loc(#loc234) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc234) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc234) + %c128_i32_2 = arith.constant 128 : i32 loc(#loc235) + %c128_i32_3 = arith.constant 128 : i32 loc(#loc235) + %2 = arith.divsi %1, %c128_i32_3 : i32 loc(#loc235) + tt.return %2 : i32 loc(#loc236) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc237) + tt.return %3 : i32 loc(#loc237) + } loc(#loc232) + tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_128__(0, 1)cconstexpr_128__(1,)cconstexpr_fp32_"() -> tensor<128x128xf32> attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f32 loc(#loc239) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc239) + tt.return %cst_0 : tensor<128x128xf32> loc(#loc240) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc241) + tt.return %0 : tensor<128x128xf32> loc(#loc241) + } loc(#loc238) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16_i32S128S_i32S128S_i32_i32_i32__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: !tt.ptr loc("ptr"(#loc242)), %offs_m: tensor<128xi32> loc("offs_m"(#loc242)), %offs_n: tensor<128xi32> loc("offs_n"(#loc242)), %stride_m: i32 loc("stride_m"(#loc242)), %stride_n: i32 loc("stride_n"(#loc242)), %M_LEN: i32 loc("M_LEN"(#loc242))) -> tensor<128x128xbf16> attributes {noinline = false} { + %ptr_0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc709) + %ptr_1 = tt.splat %stride_m : i32 -> tensor<128x1xi32> loc(#loc710) + %ptr_2 = arith.muli %ptr_0, %ptr_1 : tensor<128x1xi32> loc(#loc710) + %ptr_3 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc711) + %ptr_4 = tt.addptr %ptr_3, %ptr_2 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc711) + %ptr_5 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc712) + %ptr_6 = tt.splat %stride_n : i32 -> tensor<1x128xi32> loc(#loc713) + %ptr_7 = arith.muli %ptr_5, %ptr_6 : tensor<1x128xi32> loc(#loc713) + %ptr_8 = tt.broadcast %ptr_4 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc714) + %ptr_9 = tt.broadcast %ptr_7 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc714) + %ptr_10 = tt.addptr %ptr_8, %ptr_9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc714) + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc249) + %1 = tt.splat %M_LEN : i32 -> tensor<128x1xi32> loc(#loc250) + %2 = arith.cmpi slt, %0, %1 : tensor<128x1xi32> loc(#loc250) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc251) + %3 = tt.broadcast %2 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc251) + %cst_11 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc251) + %4 = arith.truncf %cst_11 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc251) + %5 = tt.load %ptr_10, %3, %4 : tensor<128x128x!tt.ptr> loc(#loc251) + tt.return %5 : tensor<128x128xbf16> loc(#loc252) + ^bb1: // no predecessors + %6 = ub.poison : tensor<128x128xbf16> loc(#loc253) + tt.return %6 : tensor<128x128xbf16> loc(#loc253) + } loc(#loc242) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc254)), %arg_K: !tt.ptr loc("arg_K"(#loc254)), %arg_V: !tt.ptr loc("arg_V"(#loc254)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc254)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc254)), %arg_DO: !tt.ptr loc("arg_DO"(#loc254)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc254)), %arg_DV: !tt.ptr loc("arg_DV"(#loc254)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc254)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc254)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc254)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc254)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc254)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc254)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc254)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc254)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc254)), %ks0: i32 loc("ks0"(#loc254)), %ks1: i32 loc("ks1"(#loc254)), %K: !tt.ptr loc("K"(#loc254)), %V: !tt.ptr loc("V"(#loc254)), %dq: tensor<128x128xf32> loc("dq"(#loc254)), %q: tensor<128x128xbf16> loc("q"(#loc254)), %do: tensor<128x128xbf16> loc("do"(#loc254)), %Di: tensor<128xf32> loc("Di"(#loc254)), %lse: tensor<128x1xf32> loc("lse"(#loc254)), %off_z: i32 loc("off_z"(#loc254)), %off_hq: i32 loc("off_hq"(#loc254)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc254)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc254)), %stride_kn: i32 loc("stride_kn"(#loc254)), %stride_kd: i32 loc("stride_kd"(#loc254)), %stride_vn: i32 loc("stride_vn"(#loc254)), %stride_vd: i32 loc("stride_vd"(#loc254)), %kv_indices: !tt.ptr loc("kv_indices"(#loc254)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc254))) -> tensor<128x128xf32> attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc751) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc752) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc753) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc754) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc754) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc755) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc755) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc756) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc757) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc757) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc758) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc758) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc758) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc759) + %vT_ptrs_10 = tt.splat %stride_vn : i32 -> tensor<1x64xi32> loc(#loc760) + %vT_ptrs_11 = arith.muli %vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc760) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc761) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc761) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc762) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc763) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc763) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc764) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc764) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc764) + %hi = arith.constant 2 : i32 loc(#loc765) + %hi_20 = arith.constant 2 : i32 loc(#loc765) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc765) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc766) + %hi_23 = arith.constant 1 : i32 loc(#loc767) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc767) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc768) + %c0_i32 = arith.constant 0 : i32 loc(#loc273) + %c1_i32 = arith.constant 1 : i32 loc(#loc273) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc273) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc273) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc273) + %3 = ub.poison : i32 loc(#loc273) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %ks0, %ks1, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc770) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc771) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc772) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc773) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc773) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : i32 loc(#loc774) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc775) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc775) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc776) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc776) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc281) + } loc(#loc1101) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc282) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc283) + tt.return %4 : tensor<128x128xf32> loc(#loc283) + } loc(#loc254) + tt.func private @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%x: i32 loc("x"(#loc232))) -> i32 attributes {noinline = false} { + %c64_i32 = arith.constant 64 : i32 loc(#loc233) + %c64_i32_0 = arith.constant 64 : i32 loc(#loc233) + %0 = arith.addi %x, %c64_i32_0 : i32 loc(#loc233) + %c1_i32 = arith.constant 1 : i32 loc(#loc234) + %c1_i32_1 = arith.constant 1 : i32 loc(#loc234) + %1 = arith.subi %0, %c1_i32_1 : i32 loc(#loc234) + %c64_i32_2 = arith.constant 64 : i32 loc(#loc235) + %c64_i32_3 = arith.constant 64 : i32 loc(#loc235) + %2 = arith.divsi %1, %c64_i32_3 : i32 loc(#loc235) + tt.return %2 : i32 loc(#loc236) + ^bb1: // no predecessors + %3 = ub.poison : i32 loc(#loc237) + tt.return %3 : i32 loc(#loc237) + } loc(#loc232) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc284)), %arg_K: !tt.ptr loc("arg_K"(#loc284)), %arg_V: !tt.ptr loc("arg_V"(#loc284)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc284)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc284)), %arg_DO: !tt.ptr loc("arg_DO"(#loc284)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc284)), %arg_DV: !tt.ptr loc("arg_DV"(#loc284)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc284)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc284)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc284)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc284)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc284)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc284)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc284)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc284)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc284)), %ks0: i32 loc("ks0"(#loc284)), %ks1: i32 loc("ks1"(#loc284)), %dq: tensor<128x128xf32> loc("dq"(#loc284)), %q: tensor<128x128xbf16> loc("q"(#loc284)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc284)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc284)), %do: tensor<128x128xbf16> loc("do"(#loc284)), %Di: tensor<128xf32> loc("Di"(#loc284)), %lse: tensor<128x1xf32> loc("lse"(#loc284)), %Q_LEN: i32 loc("Q_LEN"(#loc284)), %KV_LEN: i32 loc("KV_LEN"(#loc284)), %off_z: i32 loc("off_z"(#loc284)), %off_hq: i32 loc("off_hq"(#loc284)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc284)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc284)), %offs_k: tensor<128xi32> loc("offs_k"(#loc284)), %offs_v: tensor<128xi32> loc("offs_v"(#loc284)), %stride_kn: i32 loc("stride_kn"(#loc284)), %stride_kd: i32 loc("stride_kd"(#loc284)), %stride_vn: i32 loc("stride_vn"(#loc284)), %stride_vd: i32 loc("stride_vd"(#loc284)), %kv_indices: !tt.ptr loc("kv_indices"(#loc284)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc284))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc817) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc818) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc818) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc818) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc819) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc819) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc819) + %qk_5 = arith.mulf %qk_1, %qk_4 : tensor<128x64xf32> loc(#loc819) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc820) + %n_6 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S1_64S_i32__(%n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc821) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc822) + %m_7 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S128_1S_i32__(%m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc823) + %post_mod_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc824) + %post_mod_scores_8 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc825) + %post_mod_scores_9 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_8 : tensor<1x64xi32> loc(#loc825) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc826) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc826) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc826) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc826) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_5, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc826) + %tmp2 = arith.constant 0 : i32 loc(#loc827) + %tmp2_15 = arith.constant dense<0> : tensor<1xi32> loc(#loc827) + %tmp3 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc828) + %tmp3_16 = tt.broadcast %tmp3 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc828) + %tmp3_17 = arith.cmpi slt, %m_7, %tmp3_16 : tensor<128x1xi32> loc(#loc828) + %tmp5 = tt.broadcast %n_6 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc829) + %tmp5_18 = tt.broadcast %m_7 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc829) + %tmp5_19 = arith.cmpi sle, %tmp5, %tmp5_18 : tensor<128x64xi32> loc(#loc829) + %tmp6 = tt.broadcast %tmp3_17 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc830) + %tmp6_20 = arith.andi %tmp6, %tmp5_19 : tensor<128x64xi1> loc(#loc830) + %tmp7 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc831) + %tmp7_21 = tt.broadcast %tmp7 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc831) + %tmp7_22 = arith.cmpi sge, %m_7, %tmp7_21 : tensor<128x1xi32> loc(#loc831) + %tmp8 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc832) + %tmp8_23 = tt.broadcast %tmp8 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc832) + %tmp8_24 = arith.cmpi slt, %n_6, %tmp8_23 : tensor<1x64xi32> loc(#loc832) + %tmp9 = tt.broadcast %tmp7_22 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc833) + %tmp9_25 = tt.broadcast %tmp8_24 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc833) + %tmp9_26 = arith.andi %tmp9, %tmp9_25 : tensor<128x64xi1> loc(#loc833) + %tmp10 = arith.constant 0 : i32 loc(#loc834) + %tmp10_27 = arith.extui %tmp8_24 : tensor<1x64xi1> to tensor<1x64xi32> loc(#loc834) + %tmp10_28 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc834) + %tmp10_29 = arith.cmpi eq, %tmp10_27, %tmp10_28 : tensor<1x64xi32> loc(#loc834) + %tmp11 = tt.broadcast %tmp7_22 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc835) + %tmp11_30 = tt.broadcast %tmp10_29 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc835) + %tmp11_31 = arith.andi %tmp11, %tmp11_30 : tensor<128x64xi1> loc(#loc835) + %tmp12 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc836) + %tmp12_32 = tt.broadcast %tmp12 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc836) + %tmp12_33 = arith.subi %m_7, %tmp12_32 : tensor<128x1xi32> loc(#loc836) + %tmp13 = arith.constant 16 : i32 loc(#loc837) + %tmp13_34 = arith.constant dense<16> : tensor<1xi32> loc(#loc837) + %tmp14 = arith.constant 0 : i32 loc(#loc838) + %tmp14_35 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc838) + %tmp14_36 = arith.cmpi slt, %tmp12_33, %tmp14_35 : tensor<128x1xi32> loc(#loc838) + %tmp14_37 = arith.constant 0 : i32 loc(#loc839) + %tmp14_38 = arith.constant dense<0> : tensor<1xi32> loc(#loc839) + %tmp14_39 = arith.cmpi slt, %tmp13_34, %tmp14_38 : tensor<1xi32> loc(#loc839) + %tmp14_40 = tt.expand_dims %tmp14_39 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc840) + %tmp14_41 = tt.broadcast %tmp14_40 : tensor<1x1xi1> -> tensor<128x1xi1> loc(#loc840) + %tmp14_42 = arith.cmpi ne, %tmp14_36, %tmp14_41 : tensor<128x1xi1> loc(#loc840) + %tmp14_43 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc841) + %tmp14_44 = tt.broadcast %tmp14_43 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc841) + %tmp14_45 = arith.remsi %tmp12_33, %tmp14_44 : tensor<128x1xi32> loc(#loc841) + %tmp14_46 = arith.constant 0 : i32 loc(#loc842) + %tmp14_47 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc842) + %tmp14_48 = arith.cmpi ne, %tmp14_45, %tmp14_47 : tensor<128x1xi32> loc(#loc842) + %tmp14_49 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc843) + %tmp14_50 = tt.broadcast %tmp14_49 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc843) + %tmp14_51 = arith.divsi %tmp12_33, %tmp14_50 : tensor<128x1xi32> loc(#loc843) + %tmp14_52 = arith.constant 1 : i32 loc(#loc844) + %tmp14_53 = arith.constant 1 : i32 loc(#loc844) + %tmp14_54 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc844) + %tmp14_55 = arith.subi %tmp14_51, %tmp14_54 : tensor<128x1xi32> loc(#loc844) + %tmp14_56 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc845) + %tmp14_57 = tt.broadcast %tmp14_56 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc845) + %tmp14_58 = arith.divsi %tmp12_33, %tmp14_57 : tensor<128x1xi32> loc(#loc845) + %tmp14_59 = arith.select %tmp14_48, %tmp14_55, %tmp14_58 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc846) + %tmp14_60 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc847) + %tmp14_61 = tt.broadcast %tmp14_60 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc847) + %tmp14_62 = arith.divsi %tmp12_33, %tmp14_61 : tensor<128x1xi32> loc(#loc847) + %tmp14_63 = arith.select %tmp14_42, %tmp14_59, %tmp14_62 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc848) + %tmp15 = tt.expand_dims %tmp2_15 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc849) + %tmp15_64 = tt.broadcast %tmp15 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc849) + %tmp15_65 = arith.subi %n_6, %tmp15_64 : tensor<1x64xi32> loc(#loc849) + %tmp16 = arith.constant 0 : i32 loc(#loc850) + %tmp16_66 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc850) + %tmp16_67 = arith.cmpi slt, %tmp15_65, %tmp16_66 : tensor<1x64xi32> loc(#loc850) + %tmp16_68 = arith.constant 0 : i32 loc(#loc851) + %tmp16_69 = arith.constant dense<0> : tensor<1xi32> loc(#loc851) + %tmp16_70 = arith.cmpi slt, %tmp13_34, %tmp16_69 : tensor<1xi32> loc(#loc851) + %tmp16_71 = tt.expand_dims %tmp16_70 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc852) + %tmp16_72 = tt.broadcast %tmp16_71 : tensor<1x1xi1> -> tensor<1x64xi1> loc(#loc852) + %tmp16_73 = arith.cmpi ne, %tmp16_67, %tmp16_72 : tensor<1x64xi1> loc(#loc852) + %tmp16_74 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc853) + %tmp16_75 = tt.broadcast %tmp16_74 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc853) + %tmp16_76 = arith.remsi %tmp15_65, %tmp16_75 : tensor<1x64xi32> loc(#loc853) + %tmp16_77 = arith.constant 0 : i32 loc(#loc854) + %tmp16_78 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc854) + %tmp16_79 = arith.cmpi ne, %tmp16_76, %tmp16_78 : tensor<1x64xi32> loc(#loc854) + %tmp16_80 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc855) + %tmp16_81 = tt.broadcast %tmp16_80 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc855) + %tmp16_82 = arith.divsi %tmp15_65, %tmp16_81 : tensor<1x64xi32> loc(#loc855) + %tmp16_83 = arith.constant 1 : i32 loc(#loc856) + %tmp16_84 = arith.constant 1 : i32 loc(#loc856) + %tmp16_85 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc856) + %tmp16_86 = arith.subi %tmp16_82, %tmp16_85 : tensor<1x64xi32> loc(#loc856) + %tmp16_87 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc857) + %tmp16_88 = tt.broadcast %tmp16_87 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc857) + %tmp16_89 = arith.divsi %tmp15_65, %tmp16_88 : tensor<1x64xi32> loc(#loc857) + %tmp16_90 = arith.select %tmp16_79, %tmp16_86, %tmp16_89 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc858) + %tmp16_91 = tt.expand_dims %tmp13_34 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc859) + %tmp16_92 = tt.broadcast %tmp16_91 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc859) + %tmp16_93 = arith.divsi %tmp15_65, %tmp16_92 : tensor<1x64xi32> loc(#loc859) + %tmp16_94 = arith.select %tmp16_73, %tmp16_90, %tmp16_93 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc860) + %tmp17 = tt.broadcast %tmp14_63 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc861) + %tmp17_95 = tt.broadcast %tmp16_94 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc861) + %tmp17_96 = arith.cmpi eq, %tmp17, %tmp17_95 : tensor<128x64xi32> loc(#loc861) + %tmp18 = arith.andi %tmp11_31, %tmp17_96 : tensor<128x64xi1> loc(#loc862) + %tmp19 = arith.ori %tmp9_26, %tmp18 : tensor<128x64xi1> loc(#loc863) + %tmp20 = arith.ori %tmp6_20, %tmp19 : tensor<128x64xi1> loc(#loc864) + %post_mod_scores_97 = arith.constant 0xFF800000 : f32 loc(#loc865) + %post_mod_scores_98 = arith.constant 0xFF800000 : f32 loc(#loc865) + %post_mod_scores_99 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc865) + %post_mod_scores_100 = arith.select %tmp20, %post_mod_scores_14, %post_mod_scores_99 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc865) + %post_mod_scores_101 = arith.constant 1.44269502 : f32 loc(#loc866) + %post_mod_scores_102 = arith.constant 1.44269502 : f32 loc(#loc866) + %post_mod_scores_103 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc866) + %post_mod_scores_104 = arith.mulf %post_mod_scores_100, %post_mod_scores_103 : tensor<128x64xf32> loc(#loc866) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc867) + %p_105 = arith.subf %post_mod_scores_104, %p : tensor<128x64xf32> loc(#loc867) + %p_106 = math.exp2 %p_105 : tensor<128x64xf32> loc(#loc868) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc869) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc870) + %dp_107 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc870) + %dp_108 = tt.dot %do, %vT, %dp_107, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc870) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc871) + %ds_109 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc872) + %ds_110 = arith.subf %dp_108, %ds_109 : tensor<128x64xf32> loc(#loc872) + %ds_111 = arith.mulf %p_106, %ds_110 : tensor<128x64xf32> loc(#loc873) + %grad_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc874) + %grad_scores_112 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc875) + %grad_scores_113 = arith.cmpi slt, %grad_scores, %grad_scores_112 : tensor<1x64xi32> loc(#loc875) + %grad_scores_114 = arith.constant 0.000000e+00 : f32 loc(#loc876) + %grad_scores_115 = arith.constant 0.000000e+00 : f32 loc(#loc876) + %grad_scores_116 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc876) + %grad_scores_117 = tt.broadcast %grad_scores_113 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc876) + %grad_scores_118 = arith.select %grad_scores_117, %ds_111, %grad_scores_116 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc876) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc877) + %scatter_mask_119 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc878) + %scatter_mask_120 = arith.cmpi slt, %scatter_mask, %scatter_mask_119 : tensor<128x1xi32> loc(#loc878) + %scatter_mask_121 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc879) + %scatter_mask_122 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc880) + %scatter_mask_123 = arith.cmpi slt, %scatter_mask_121, %scatter_mask_122 : tensor<1x64xi32> loc(#loc880) + %scatter_mask_124 = tt.broadcast %scatter_mask_120 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc881) + %scatter_mask_125 = tt.broadcast %scatter_mask_123 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc881) + %scatter_mask_126 = arith.andi %scatter_mask_124, %scatter_mask_125 : tensor<128x64xi1> loc(#loc881) + %ds_127 = arith.constant 0.000000e+00 : f32 loc(#loc882) + %ds_128 = arith.constant 0.000000e+00 : f32 loc(#loc882) + %ds_129 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc882) + %ds_130 = arith.select %tmp20, %grad_scores_118, %ds_129 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc882) + %ds_131 = arith.truncf %ds_130 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc883) + %dq_132 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc884) + %dq_133 = arith.constant 0.000000e+00 : f32 loc(#loc885) + %dq_134 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc885) + %dq_135 = tt.dot %ds_131, %dq_132, %dq_134, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc885) + %dq_136 = arith.addf %dq, %dq_135 : tensor<128x128xf32> loc(#loc886) + tt.return %dq_136 : tensor<128x128xf32> loc(#loc355) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc356) + tt.return %0 : tensor<128x128xf32> loc(#loc356) + } loc(#loc284) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%ptr: tensor<128x64x!tt.ptr> loc("ptr"(#loc242)), %offs_m: tensor<128xi32> loc("offs_m"(#loc242)), %offs_n: tensor<64xi32> loc("offs_n"(#loc242)), %N_LEN: i32 loc("N_LEN"(#loc242))) -> tensor<128x64xbf16> attributes {noinline = false} { + %0 = tt.expand_dims %offs_n {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc357) + %1 = tt.splat %N_LEN : i32 -> tensor<1x64xi32> loc(#loc358) + %2 = arith.cmpi slt, %0, %1 : tensor<1x64xi32> loc(#loc358) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc359) + %3 = tt.broadcast %2 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc359) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc359) + %4 = arith.truncf %cst_0 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc359) + %5 = tt.load %ptr, %3, %4 : tensor<128x64x!tt.ptr> loc(#loc359) + tt.return %5 : tensor<128x64xbf16> loc(#loc360) + ^bb1: // no predecessors + %6 = ub.poison : tensor<128x64xbf16> loc(#loc253) + tt.return %6 : tensor<128x64xbf16> loc(#loc253) + } loc(#loc242) + tt.func private @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S1_64S_i32__(%indices: tensor<1x64xi32> loc("indices"(#loc361)), %max_len: i32 loc("max_len"(#loc361))) -> tensor<1x64xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<1x64xi32> loc(#loc362) + %1 = arith.remsi %indices, %0 : tensor<1x64xi32> loc(#loc362) + tt.return %1 : tensor<1x64xi32> loc(#loc363) + ^bb1: // no predecessors + %2 = ub.poison : tensor<1x64xi32> loc(#loc364) + tt.return %2 : tensor<1x64xi32> loc(#loc364) + } loc(#loc361) + tt.func private @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S128_1S_i32__(%indices: tensor<128x1xi32> loc("indices"(#loc361)), %max_len: i32 loc("max_len"(#loc361))) -> tensor<128x1xi32> attributes {noinline = false} { + %0 = tt.splat %max_len : i32 -> tensor<128x1xi32> loc(#loc362) + %1 = arith.remsi %indices, %0 : tensor<128x1xi32> loc(#loc362) + tt.return %1 : tensor<128x1xi32> loc(#loc363) + ^bb1: // no predecessors + %2 = ub.poison : tensor<128x1xi32> loc(#loc364) + tt.return %2 : tensor<128x1xi32> loc(#loc364) + } loc(#loc361) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%loop_iter: i32 loc("loop_iter"(#loc365)), %col_indices: !tt.ptr loc("col_indices"(#loc365)), %total_blocks: i32 loc("total_blocks"(#loc365))) -> i32 attributes {noinline = false} { + %cur_block_idx = arith.constant 2 : i32 loc(#loc893) + %cur_block_idx_0 = arith.constant 2 : i32 loc(#loc893) + %cur_block_idx_1 = arith.divsi %loop_iter, %cur_block_idx_0 : i32 loc(#loc893) + %cur_block = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc894) + %cur_block_2 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc895) + %next_block = arith.constant 1 : i32 loc(#loc896) + %next_block_3 = arith.constant 1 : i32 loc(#loc896) + %next_block_4 = arith.addi %cur_block_idx_1, %next_block_3 : i32 loc(#loc896) + %next_block_5 = arith.cmpi slt, %next_block_4, %total_blocks : i32 loc(#loc897) + %next_block_6 = tt.addptr %col_indices, %cur_block_idx_1 : !tt.ptr, i32 loc(#loc898) + %next_block_7 = arith.constant 1 : i32 loc(#loc899) + %next_block_8 = tt.addptr %next_block_6, %next_block_7 : !tt.ptr, i32 loc(#loc899) + %next_block_9 = tt.load %next_block_8, %next_block_5 evictionPolicy = evict_last : !tt.ptr loc(#loc900) + %needs_jump = arith.constant 1 : i32 loc(#loc901) + %needs_jump_10 = arith.constant 1 : i32 loc(#loc901) + %needs_jump_11 = arith.addi %loop_iter, %needs_jump_10 : i32 loc(#loc901) + %needs_jump_12 = arith.constant 2 : i32 loc(#loc902) + %needs_jump_13 = arith.constant 2 : i32 loc(#loc902) + %needs_jump_14 = arith.remsi %needs_jump_11, %needs_jump_13 : i32 loc(#loc902) + %needs_jump_15 = arith.constant 0 : i32 loc(#loc903) + %needs_jump_16 = arith.cmpi eq, %needs_jump_14, %needs_jump_15 : i32 loc(#loc903) + %jump_to_block = arith.subi %next_block_9, %cur_block_2 : i32 loc(#loc904) + %jump_to_block_17 = arith.constant 128 : i32 loc(#loc905) + %jump_to_block_18 = arith.constant 128 : i32 loc(#loc905) + %jump_to_block_19 = arith.muli %jump_to_block, %jump_to_block_18 : i32 loc(#loc905) + %jump_to_block_20 = arith.constant 64 : i32 loc(#loc906) + %jump_to_block_21 = arith.constant 64 : i32 loc(#loc906) + %jump_to_block_22 = arith.subi %jump_to_block_19, %jump_to_block_21 : i32 loc(#loc906) + %offset = arith.extui %needs_jump_16 : i1 to i32 loc(#loc907) + %offset_23 = arith.muli %jump_to_block_22, %offset : i32 loc(#loc907) + %offset_24 = arith.constant 1 : i32 loc(#loc908) + %offset_25 = arith.constant 1 : i32 loc(#loc908) + %offset_26 = arith.extui %needs_jump_16 : i1 to i32 loc(#loc908) + %offset_27 = arith.subi %offset_25, %offset_26 : i32 loc(#loc908) + %offset_28 = arith.constant 64 : i32 loc(#loc909) + %offset_29 = arith.constant 64 : i32 loc(#loc909) + %offset_30 = arith.muli %offset_27, %offset_29 : i32 loc(#loc909) + %offset_31 = arith.addi %offset_23, %offset_30 : i32 loc(#loc910) + tt.return %offset_31 : i32 loc(#loc384) + ^bb1: // no predecessors + %0 = ub.poison : i32 loc(#loc385) + tt.return %0 : i32 loc(#loc385) + } loc(#loc365) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_fp32S128_128S_bf16S128_128S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(36,)cconstexpr_bf16__(37,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc254)), %arg_K: !tt.ptr loc("arg_K"(#loc254)), %arg_V: !tt.ptr loc("arg_V"(#loc254)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc254)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc254)), %arg_DO: !tt.ptr loc("arg_DO"(#loc254)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc254)), %arg_DV: !tt.ptr loc("arg_DV"(#loc254)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc254)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc254)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc254)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc254)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc254)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc254)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc254)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc254)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc254)), %ks0: i32 loc("ks0"(#loc254)), %ks1: i32 loc("ks1"(#loc254)), %K: !tt.ptr loc("K"(#loc254)), %V: !tt.ptr loc("V"(#loc254)), %dq: tensor<128x128xf32> loc("dq"(#loc254)), %q: tensor<128x128xbf16> loc("q"(#loc254)), %do: tensor<128x128xbf16> loc("do"(#loc254)), %Di: tensor<128xf32> loc("Di"(#loc254)), %lse: tensor<128x1xf32> loc("lse"(#loc254)), %off_z: i32 loc("off_z"(#loc254)), %off_hq: i32 loc("off_hq"(#loc254)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc254)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc254)), %stride_kn: i32 loc("stride_kn"(#loc254)), %stride_kd: i32 loc("stride_kd"(#loc254)), %stride_vn: i32 loc("stride_vn"(#loc254)), %stride_vd: i32 loc("stride_vd"(#loc254)), %kv_indices: !tt.ptr loc("kv_indices"(#loc254)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc254))) -> tensor<128x128xf32> attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc751) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc752) + %kT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc753) + %kT_ptrs_0 = tt.splat %stride_kn : i32 -> tensor<1x64xi32> loc(#loc754) + %kT_ptrs_1 = arith.muli %kT_ptrs, %kT_ptrs_0 : tensor<1x64xi32> loc(#loc754) + %kT_ptrs_2 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc755) + %kT_ptrs_3 = tt.addptr %kT_ptrs_2, %kT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc755) + %kT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc756) + %kT_ptrs_5 = tt.splat %stride_kd : i32 -> tensor<128x1xi32> loc(#loc757) + %kT_ptrs_6 = arith.muli %kT_ptrs_4, %kT_ptrs_5 : tensor<128x1xi32> loc(#loc757) + %kT_ptrs_7 = tt.broadcast %kT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc758) + %kT_ptrs_8 = tt.broadcast %kT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc758) + %kT_ptrs_9 = tt.addptr %kT_ptrs_7, %kT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc758) + %vT_ptrs = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc759) + %vT_ptrs_10 = tt.splat %stride_vn : i32 -> tensor<1x64xi32> loc(#loc760) + %vT_ptrs_11 = arith.muli %vT_ptrs, %vT_ptrs_10 : tensor<1x64xi32> loc(#loc760) + %vT_ptrs_12 = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc761) + %vT_ptrs_13 = tt.addptr %vT_ptrs_12, %vT_ptrs_11 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc761) + %vT_ptrs_14 = tt.expand_dims %offs_v {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc762) + %vT_ptrs_15 = tt.splat %stride_vd : i32 -> tensor<128x1xi32> loc(#loc763) + %vT_ptrs_16 = arith.muli %vT_ptrs_14, %vT_ptrs_15 : tensor<128x1xi32> loc(#loc763) + %vT_ptrs_17 = tt.broadcast %vT_ptrs_13 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc764) + %vT_ptrs_18 = tt.broadcast %vT_ptrs_16 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc764) + %vT_ptrs_19 = tt.addptr %vT_ptrs_17, %vT_ptrs_18 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc764) + %hi = arith.constant 2 : i32 loc(#loc765) + %hi_20 = arith.constant 2 : i32 loc(#loc765) + %hi_21 = arith.muli %sparse_kv_num_blocks, %hi_20 : i32 loc(#loc765) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks1) : (i32) -> i32 loc(#loc766) + %hi_23 = arith.constant 1 : i32 loc(#loc767) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc767) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc768) + %c0_i32 = arith.constant 0 : i32 loc(#loc273) + %c1_i32 = arith.constant 1 : i32 loc(#loc273) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc273) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc273) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc273) + %3 = ub.poison : i32 loc(#loc273) + %vT_ptrs_26:4 = scf.for %start_n = %0 to %1 step %2 iter_args(%dq_27 = %dq, %offs_n2_28 = %offs_n2, %kT_ptrs_29 = %kT_ptrs_9, %vT_ptrs_30 = %vT_ptrs_19) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %dq_31 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %dq_27, %q, %kT_ptrs_29, %vT_ptrs_30, %do, %Di, %lse, %ks0, %ks1, %off_z, %off_hq, %offs_m2, %offs_n2_28, %offs_k, %offs_v, %stride_kn, %stride_kd, %stride_vn, %stride_vd, %kv_indices, %sparse_kv_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128x128xbf16>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128xf32>, tensor<128x1xf32>, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> tensor<128x128xf32> loc(#loc770) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_n, %kv_indices, %sparse_kv_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc771) + %kT_ptrs_32 = arith.muli %offset, %stride_kn : i32 loc(#loc772) + %kT_ptrs_33 = tt.splat %kT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc773) + %kT_ptrs_34 = tt.addptr %kT_ptrs_29, %kT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc773) + %vT_ptrs_35 = arith.muli %offset, %stride_vn : i32 loc(#loc774) + %vT_ptrs_36 = tt.splat %vT_ptrs_35 : i32 -> tensor<128x64xi32> loc(#loc775) + %vT_ptrs_37 = tt.addptr %vT_ptrs_30, %vT_ptrs_36 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc775) + %offs_n2_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc776) + %offs_n2_39 = arith.addi %offs_n2_28, %offs_n2_38 : tensor<64xi32> loc(#loc776) + scf.yield %dq_31, %offs_n2_39, %kT_ptrs_34, %vT_ptrs_37 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc281) + } loc(#loc1101) + tt.return %vT_ptrs_26#0 : tensor<128x128xf32> loc(#loc282) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc283) + tt.return %4 : tensor<128x128xf32> loc(#loc283) + } loc(#loc254) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dq_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_bf16S128_128S_Pbf16S128_64S_Pbf16S128_64S_bf16S128_128S_fp32S128S_fp32S128_1S_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(40,)cconstexpr_bf16__(41,)cconstexpr_1_d_44269504__(42,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc284)), %arg_K: !tt.ptr loc("arg_K"(#loc284)), %arg_V: !tt.ptr loc("arg_V"(#loc284)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc284)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc284)), %arg_DO: !tt.ptr loc("arg_DO"(#loc284)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc284)), %arg_DV: !tt.ptr loc("arg_DV"(#loc284)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc284)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc284)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc284)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc284)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc284)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc284)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc284)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc284)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc284)), %ks0: i32 loc("ks0"(#loc284)), %ks1: i32 loc("ks1"(#loc284)), %dq: tensor<128x128xf32> loc("dq"(#loc284)), %q: tensor<128x128xbf16> loc("q"(#loc284)), %kT_ptrs: tensor<128x64x!tt.ptr> loc("kT_ptrs"(#loc284)), %vT_ptrs: tensor<128x64x!tt.ptr> loc("vT_ptrs"(#loc284)), %do: tensor<128x128xbf16> loc("do"(#loc284)), %Di: tensor<128xf32> loc("Di"(#loc284)), %lse: tensor<128x1xf32> loc("lse"(#loc284)), %Q_LEN: i32 loc("Q_LEN"(#loc284)), %KV_LEN: i32 loc("KV_LEN"(#loc284)), %off_z: i32 loc("off_z"(#loc284)), %off_hq: i32 loc("off_hq"(#loc284)), %offs_m2: tensor<128xi32> loc("offs_m2"(#loc284)), %offs_n2: tensor<64xi32> loc("offs_n2"(#loc284)), %offs_k: tensor<128xi32> loc("offs_k"(#loc284)), %offs_v: tensor<128xi32> loc("offs_v"(#loc284)), %stride_kn: i32 loc("stride_kn"(#loc284)), %stride_kd: i32 loc("stride_kd"(#loc284)), %stride_vn: i32 loc("stride_vn"(#loc284)), %stride_vd: i32 loc("stride_vd"(#loc284)), %kv_indices: !tt.ptr loc("kv_indices"(#loc284)), %sparse_kv_num_blocks: i32 loc("sparse_kv_num_blocks"(#loc284))) -> tensor<128x128xf32> attributes {noinline = false} { + %kT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%kT_ptrs, %offs_k, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc817) + %qk = arith.constant 0.000000e+00 : f32 loc(#loc818) + %qk_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc818) + %qk_1 = tt.dot %q, %kT, %qk_0, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc818) + %qk_2 = arith.constant 0.0883883461 : f32 loc(#loc819) + %qk_3 = arith.constant 0.0883883461 : f32 loc(#loc819) + %qk_4 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc819) + %qk_5 = arith.mulf %qk_1, %qk_4 : tensor<128x64xf32> loc(#loc819) + %n = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc820) + %n_6 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S1_64S_i32__(%n, %KV_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc821) + %m = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc822) + %m_7 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S128_1S_i32__(%m, %Q_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc823) + %post_mod_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc824) + %post_mod_scores_8 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc825) + %post_mod_scores_9 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_8 : tensor<1x64xi32> loc(#loc825) + %post_mod_scores_10 = arith.constant 0xFF800000 : f32 loc(#loc826) + %post_mod_scores_11 = arith.constant 0xFF800000 : f32 loc(#loc826) + %post_mod_scores_12 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc826) + %post_mod_scores_13 = tt.broadcast %post_mod_scores_9 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc826) + %post_mod_scores_14 = arith.select %post_mod_scores_13, %qk_5, %post_mod_scores_12 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc826) + %post_mod_scores_15 = arith.constant 1.44269502 : f32 loc(#loc866) + %post_mod_scores_16 = arith.constant 1.44269502 : f32 loc(#loc866) + %post_mod_scores_17 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc866) + %post_mod_scores_18 = arith.mulf %post_mod_scores_14, %post_mod_scores_17 : tensor<128x64xf32> loc(#loc866) + %p = tt.broadcast %lse : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc867) + %p_19 = arith.subf %post_mod_scores_18, %p : tensor<128x64xf32> loc(#loc867) + %p_20 = math.exp2 %p_19 : tensor<128x64xf32> loc(#loc868) + %vT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%vT_ptrs, %offs_v, %offs_n2, %KV_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc869) + %dp = arith.constant 0.000000e+00 : f32 loc(#loc870) + %dp_21 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc870) + %dp_22 = tt.dot %do, %vT, %dp_21, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc870) + %ds = tt.expand_dims %Di {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc871) + %ds_23 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc872) + %ds_24 = arith.subf %dp_22, %ds_23 : tensor<128x64xf32> loc(#loc872) + %ds_25 = arith.mulf %p_20, %ds_24 : tensor<128x64xf32> loc(#loc873) + %grad_scores = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc874) + %grad_scores_26 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc875) + %grad_scores_27 = arith.cmpi slt, %grad_scores, %grad_scores_26 : tensor<1x64xi32> loc(#loc875) + %grad_scores_28 = arith.constant 0.000000e+00 : f32 loc(#loc876) + %grad_scores_29 = arith.constant 0.000000e+00 : f32 loc(#loc876) + %grad_scores_30 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc876) + %grad_scores_31 = tt.broadcast %grad_scores_27 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc876) + %grad_scores_32 = arith.select %grad_scores_31, %ds_25, %grad_scores_30 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc876) + %scatter_mask = tt.expand_dims %offs_m2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc877) + %scatter_mask_33 = tt.splat %Q_LEN : i32 -> tensor<128x1xi32> loc(#loc878) + %scatter_mask_34 = arith.cmpi slt, %scatter_mask, %scatter_mask_33 : tensor<128x1xi32> loc(#loc878) + %scatter_mask_35 = tt.expand_dims %offs_n2 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc879) + %scatter_mask_36 = tt.splat %KV_LEN : i32 -> tensor<1x64xi32> loc(#loc880) + %scatter_mask_37 = arith.cmpi slt, %scatter_mask_35, %scatter_mask_36 : tensor<1x64xi32> loc(#loc880) + %scatter_mask_38 = tt.broadcast %scatter_mask_34 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc881) + %scatter_mask_39 = tt.broadcast %scatter_mask_37 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc881) + %scatter_mask_40 = arith.andi %scatter_mask_38, %scatter_mask_39 : tensor<128x64xi1> loc(#loc881) + %ds_41 = arith.truncf %grad_scores_32 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc883) + %dq_42 = tt.trans %kT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc884) + %dq_43 = arith.constant 0.000000e+00 : f32 loc(#loc885) + %dq_44 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc885) + %dq_45 = tt.dot %ds_41, %dq_42, %dq_44, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc885) + %dq_46 = arith.addf %dq, %dq_45 : tensor<128x128xf32> loc(#loc886) + tt.return %dq_46 : tensor<128x128xf32> loc(#loc355) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc356) + tt.return %0 : tensor<128x128xf32> loc(#loc356) + } loc(#loc284) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(37,)cconstexpr_bf16__(38,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc386)), %arg_K: !tt.ptr loc("arg_K"(#loc386)), %arg_V: !tt.ptr loc("arg_V"(#loc386)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc386)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc386)), %arg_DO: !tt.ptr loc("arg_DO"(#loc386)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc386)), %arg_DV: !tt.ptr loc("arg_DV"(#loc386)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc386)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc386)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc386)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc386)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc386)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc386)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc386)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc386)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc386)), %ks0: i32 loc("ks0"(#loc386)), %ks1: i32 loc("ks1"(#loc386)), %Q: !tt.ptr loc("Q"(#loc386)), %DO: !tt.ptr loc("DO"(#loc386)), %DELTA: !tt.ptr loc("DELTA"(#loc386)), %LSE: !tt.ptr loc("LSE"(#loc386)), %dk: tensor<128x128xf32> loc("dk"(#loc386)), %dv: tensor<128x128xf32> loc("dv"(#loc386)), %k: tensor<128x128xbf16> loc("k"(#loc386)), %v: tensor<128x128xbf16> loc("v"(#loc386)), %off_z: i32 loc("off_z"(#loc386)), %off_hq: i32 loc("off_hq"(#loc386)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc386)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc386)), %stride_qm: i32 loc("stride_qm"(#loc386)), %stride_qd: i32 loc("stride_qd"(#loc386)), %stride_dom: i32 loc("stride_dom"(#loc386)), %stride_dod: i32 loc("stride_dod"(#loc386)), %q_indices: !tt.ptr loc("q_indices"(#loc386)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc386))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc948) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc949) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc950) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc951) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc951) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc952) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc952) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc953) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc954) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc954) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc955) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc955) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc955) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc956) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc957) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc957) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc958) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc958) + %do_ptrs_14 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc959) + %do_ptrs_15 = tt.splat %stride_dod : i32 -> tensor<1x128xi32> loc(#loc960) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc960) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc961) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc961) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc961) + %hi = arith.constant 2 : i32 loc(#loc962) + %hi_20 = arith.constant 2 : i32 loc(#loc962) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc962) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks0) : (i32) -> i32 loc(#loc963) + %hi_23 = arith.constant 1 : i32 loc(#loc964) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc964) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc965) + %c0_i32 = arith.constant 0 : i32 loc(#loc405) + %c1_i32 = arith.constant 1 : i32 loc(#loc405) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc405) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc405) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc405) + %3 = ub.poison : i32 loc(#loc405) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(41,)cconstexpr_bf16__(42,)cconstexpr_1_d_44269504__(43,)cconstexpr_False_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %ks0, %ks1, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc406) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc967) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc968) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc969) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc969) + %do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc970) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc971) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc971) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc972) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc972) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc413) + } loc(#loc1103) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc414) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc415) + %5 = ub.poison : tensor<128x128xf32> loc(#loc415) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc415) + } loc(#loc386) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(41,)cconstexpr_bf16__(42,)cconstexpr_1_d_44269504__(43,)cconstexpr_False_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc416)), %arg_K: !tt.ptr loc("arg_K"(#loc416)), %arg_V: !tt.ptr loc("arg_V"(#loc416)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc416)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc416)), %arg_DO: !tt.ptr loc("arg_DO"(#loc416)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc416)), %arg_DV: !tt.ptr loc("arg_DV"(#loc416)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc416)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc416)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc416)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc416)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc416)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc416)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc416)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc416)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc416)), %ks0: i32 loc("ks0"(#loc416)), %ks1: i32 loc("ks1"(#loc416)), %dk: tensor<128x128xf32> loc("dk"(#loc416)), %dv: tensor<128x128xf32> loc("dv"(#loc416)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc416)), %k: tensor<128x128xbf16> loc("k"(#loc416)), %v: tensor<128x128xbf16> loc("v"(#loc416)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc416)), %DELTA: !tt.ptr loc("DELTA"(#loc416)), %LSE: !tt.ptr loc("LSE"(#loc416)), %Q_LEN: i32 loc("Q_LEN"(#loc416)), %KV_LEN: i32 loc("KV_LEN"(#loc416)), %off_z: i32 loc("off_z"(#loc416)), %off_hq: i32 loc("off_hq"(#loc416)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc416)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc416)), %offs_k: tensor<128xi32> loc("offs_k"(#loc416)), %offs_v: tensor<128xi32> loc("offs_v"(#loc416)), %stride_qm: i32 loc("stride_qm"(#loc416)), %stride_qd: i32 loc("stride_qd"(#loc416)), %stride_dom: i32 loc("stride_dom"(#loc416)), %stride_dod: i32 loc("stride_dod"(#loc416)), %q_indices: !tt.ptr loc("q_indices"(#loc416)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc416))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc1014) + %lse = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1015) + %lse_0 = arith.cmpi slt, %offs_m1, %lse : tensor<64xi32> loc(#loc1015) + %lse_1 = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1016) + %lse_2 = tt.addptr %lse_1, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1016) + %lse_3 = tt.load %lse_2, %lse_0 : tensor<64x!tt.ptr> loc(#loc1017) + %lse_4 = arith.constant 0xFF800000 : f32 loc(#loc1018) + %lse_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1018) + %lse_6 = arith.cmpf oeq, %lse_3, %lse_5 : tensor<64xf32> loc(#loc1018) + %lse_7 = arith.constant 0.000000e+00 : f32 loc(#loc1019) + %lse_8 = arith.constant 0.000000e+00 : f32 loc(#loc1019) + %lse_9 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1019) + %lse_10 = arith.select %lse_6, %lse_9, %lse_3 : tensor<64xi1>, tensor<64xf32> loc(#loc1019) + %qkT = arith.constant 0.000000e+00 : f32 loc(#loc1020) + %qkT_11 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1020) + %qkT_12 = tt.dot %k, %qT, %qkT_11, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1020) + %qkT_13 = arith.constant 0.0883883461 : f32 loc(#loc1021) + %qkT_14 = arith.constant 0.0883883461 : f32 loc(#loc1021) + %qkT_15 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1021) + %qkT_16 = arith.mulf %qkT_12, %qkT_15 : tensor<128x64xf32> loc(#loc1021) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1022) + %m_17 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S1_64S_i32__(%m, %Q_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc1023) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc1024) + %n_18 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S128_1S_i32__(%n, %KV_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc1025) + %post_mod_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1026) + %post_mod_scores_19 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1027) + %post_mod_scores_20 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_19 : tensor<1x64xi32> loc(#loc1027) + %post_mod_scores_21 = arith.constant 0xFF800000 : f32 loc(#loc1028) + %post_mod_scores_22 = arith.constant 0xFF800000 : f32 loc(#loc1028) + %post_mod_scores_23 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1028) + %post_mod_scores_24 = tt.broadcast %post_mod_scores_20 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1028) + %post_mod_scores_25 = arith.select %post_mod_scores_24, %qkT_16, %post_mod_scores_23 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1028) + %tmp24 = arith.constant 0 : i32 loc(#loc1029) + %tmp24_26 = arith.constant dense<0> : tensor<1xi32> loc(#loc1029) + %tmp25 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1030) + %tmp25_27 = tt.broadcast %tmp25 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1030) + %tmp25_28 = arith.cmpi slt, %m_17, %tmp25_27 : tensor<1x64xi32> loc(#loc1030) + %tmp27 = tt.broadcast %n_18 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc1031) + %tmp27_29 = tt.broadcast %m_17 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc1031) + %tmp27_30 = arith.cmpi sle, %tmp27, %tmp27_29 : tensor<128x64xi32> loc(#loc1031) + %tmp28 = tt.broadcast %tmp25_28 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1032) + %tmp28_31 = arith.andi %tmp28, %tmp27_30 : tensor<128x64xi1> loc(#loc1032) + %tmp29 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1033) + %tmp29_32 = tt.broadcast %tmp29 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1033) + %tmp29_33 = arith.cmpi sge, %m_17, %tmp29_32 : tensor<1x64xi32> loc(#loc1033) + %tmp30 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1034) + %tmp30_34 = tt.broadcast %tmp30 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1034) + %tmp30_35 = arith.cmpi slt, %n_18, %tmp30_34 : tensor<128x1xi32> loc(#loc1034) + %tmp31 = tt.broadcast %tmp29_33 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1035) + %tmp31_36 = tt.broadcast %tmp30_35 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc1035) + %tmp31_37 = arith.andi %tmp31, %tmp31_36 : tensor<128x64xi1> loc(#loc1035) + %tmp32 = arith.constant 0 : i32 loc(#loc1036) + %tmp32_38 = arith.extui %tmp30_35 : tensor<128x1xi1> to tensor<128x1xi32> loc(#loc1036) + %tmp32_39 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1036) + %tmp32_40 = arith.cmpi eq, %tmp32_38, %tmp32_39 : tensor<128x1xi32> loc(#loc1036) + %tmp33 = tt.broadcast %tmp29_33 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1037) + %tmp33_41 = tt.broadcast %tmp32_40 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc1037) + %tmp33_42 = arith.andi %tmp33, %tmp33_41 : tensor<128x64xi1> loc(#loc1037) + %tmp34 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1038) + %tmp34_43 = tt.broadcast %tmp34 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1038) + %tmp34_44 = arith.subi %m_17, %tmp34_43 : tensor<1x64xi32> loc(#loc1038) + %tmp35 = arith.constant 16 : i32 loc(#loc1039) + %tmp35_45 = arith.constant dense<16> : tensor<1xi32> loc(#loc1039) + %tmp36 = arith.constant 0 : i32 loc(#loc1040) + %tmp36_46 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1040) + %tmp36_47 = arith.cmpi slt, %tmp34_44, %tmp36_46 : tensor<1x64xi32> loc(#loc1040) + %tmp36_48 = arith.constant 0 : i32 loc(#loc1041) + %tmp36_49 = arith.constant dense<0> : tensor<1xi32> loc(#loc1041) + %tmp36_50 = arith.cmpi slt, %tmp35_45, %tmp36_49 : tensor<1xi32> loc(#loc1041) + %tmp36_51 = tt.expand_dims %tmp36_50 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc1042) + %tmp36_52 = tt.broadcast %tmp36_51 : tensor<1x1xi1> -> tensor<1x64xi1> loc(#loc1042) + %tmp36_53 = arith.cmpi ne, %tmp36_47, %tmp36_52 : tensor<1x64xi1> loc(#loc1042) + %tmp36_54 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1043) + %tmp36_55 = tt.broadcast %tmp36_54 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1043) + %tmp36_56 = arith.remsi %tmp34_44, %tmp36_55 : tensor<1x64xi32> loc(#loc1043) + %tmp36_57 = arith.constant 0 : i32 loc(#loc1044) + %tmp36_58 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1044) + %tmp36_59 = arith.cmpi ne, %tmp36_56, %tmp36_58 : tensor<1x64xi32> loc(#loc1044) + %tmp36_60 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1045) + %tmp36_61 = tt.broadcast %tmp36_60 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1045) + %tmp36_62 = arith.divsi %tmp34_44, %tmp36_61 : tensor<1x64xi32> loc(#loc1045) + %tmp36_63 = arith.constant 1 : i32 loc(#loc1046) + %tmp36_64 = arith.constant 1 : i32 loc(#loc1046) + %tmp36_65 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc1046) + %tmp36_66 = arith.subi %tmp36_62, %tmp36_65 : tensor<1x64xi32> loc(#loc1046) + %tmp36_67 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1047) + %tmp36_68 = tt.broadcast %tmp36_67 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1047) + %tmp36_69 = arith.divsi %tmp34_44, %tmp36_68 : tensor<1x64xi32> loc(#loc1047) + %tmp36_70 = arith.select %tmp36_59, %tmp36_66, %tmp36_69 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc1048) + %tmp36_71 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1049) + %tmp36_72 = tt.broadcast %tmp36_71 : tensor<1x1xi32> -> tensor<1x64xi32> loc(#loc1049) + %tmp36_73 = arith.divsi %tmp34_44, %tmp36_72 : tensor<1x64xi32> loc(#loc1049) + %tmp36_74 = arith.select %tmp36_53, %tmp36_70, %tmp36_73 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc1050) + %tmp37 = tt.expand_dims %tmp24_26 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1051) + %tmp37_75 = tt.broadcast %tmp37 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1051) + %tmp37_76 = arith.subi %n_18, %tmp37_75 : tensor<128x1xi32> loc(#loc1051) + %tmp38 = arith.constant 0 : i32 loc(#loc1052) + %tmp38_77 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1052) + %tmp38_78 = arith.cmpi slt, %tmp37_76, %tmp38_77 : tensor<128x1xi32> loc(#loc1052) + %tmp38_79 = arith.constant 0 : i32 loc(#loc1053) + %tmp38_80 = arith.constant dense<0> : tensor<1xi32> loc(#loc1053) + %tmp38_81 = arith.cmpi slt, %tmp35_45, %tmp38_80 : tensor<1xi32> loc(#loc1053) + %tmp38_82 = tt.expand_dims %tmp38_81 {axis = 0 : i32} : tensor<1xi1> -> tensor<1x1xi1> loc(#loc1054) + %tmp38_83 = tt.broadcast %tmp38_82 : tensor<1x1xi1> -> tensor<128x1xi1> loc(#loc1054) + %tmp38_84 = arith.cmpi ne, %tmp38_78, %tmp38_83 : tensor<128x1xi1> loc(#loc1054) + %tmp38_85 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1055) + %tmp38_86 = tt.broadcast %tmp38_85 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1055) + %tmp38_87 = arith.remsi %tmp37_76, %tmp38_86 : tensor<128x1xi32> loc(#loc1055) + %tmp38_88 = arith.constant 0 : i32 loc(#loc1056) + %tmp38_89 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1056) + %tmp38_90 = arith.cmpi ne, %tmp38_87, %tmp38_89 : tensor<128x1xi32> loc(#loc1056) + %tmp38_91 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1057) + %tmp38_92 = tt.broadcast %tmp38_91 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1057) + %tmp38_93 = arith.divsi %tmp37_76, %tmp38_92 : tensor<128x1xi32> loc(#loc1057) + %tmp38_94 = arith.constant 1 : i32 loc(#loc1058) + %tmp38_95 = arith.constant 1 : i32 loc(#loc1058) + %tmp38_96 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc1058) + %tmp38_97 = arith.subi %tmp38_93, %tmp38_96 : tensor<128x1xi32> loc(#loc1058) + %tmp38_98 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1059) + %tmp38_99 = tt.broadcast %tmp38_98 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1059) + %tmp38_100 = arith.divsi %tmp37_76, %tmp38_99 : tensor<128x1xi32> loc(#loc1059) + %tmp38_101 = arith.select %tmp38_90, %tmp38_97, %tmp38_100 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc1060) + %tmp38_102 = tt.expand_dims %tmp35_45 {axis = 0 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc1061) + %tmp38_103 = tt.broadcast %tmp38_102 : tensor<1x1xi32> -> tensor<128x1xi32> loc(#loc1061) + %tmp38_104 = arith.divsi %tmp37_76, %tmp38_103 : tensor<128x1xi32> loc(#loc1061) + %tmp38_105 = arith.select %tmp38_84, %tmp38_101, %tmp38_104 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc1062) + %tmp39 = tt.broadcast %tmp36_74 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc1063) + %tmp39_106 = tt.broadcast %tmp38_105 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc1063) + %tmp39_107 = arith.cmpi eq, %tmp39, %tmp39_106 : tensor<128x64xi32> loc(#loc1063) + %tmp40 = arith.andi %tmp33_42, %tmp39_107 : tensor<128x64xi1> loc(#loc1064) + %tmp41 = arith.ori %tmp31_37, %tmp40 : tensor<128x64xi1> loc(#loc1065) + %tmp42 = arith.ori %tmp28_31, %tmp41 : tensor<128x64xi1> loc(#loc1066) + %post_mod_scores_108 = arith.constant 0xFF800000 : f32 loc(#loc1067) + %post_mod_scores_109 = arith.constant 0xFF800000 : f32 loc(#loc1067) + %post_mod_scores_110 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1067) + %post_mod_scores_111 = arith.select %tmp42, %post_mod_scores_25, %post_mod_scores_110 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1067) + %post_mod_scores_112 = arith.constant 1.44269502 : f32 loc(#loc1068) + %post_mod_scores_113 = arith.constant 1.44269502 : f32 loc(#loc1068) + %post_mod_scores_114 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1068) + %post_mod_scores_115 = arith.mulf %post_mod_scores_111, %post_mod_scores_114 : tensor<128x64xf32> loc(#loc1068) + %pT = tt.expand_dims %lse_10 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1069) + %pT_116 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1070) + %pT_117 = arith.subf %post_mod_scores_115, %pT_116 : tensor<128x64xf32> loc(#loc1070) + %pT_118 = math.exp2 %pT_117 : tensor<128x64xf32> loc(#loc1071) + %do = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1072) + %dv_119 = arith.truncf %pT_118 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1073) + %dv_120 = arith.constant 0.000000e+00 : f32 loc(#loc1074) + %dv_121 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1074) + %dv_122 = tt.dot %dv_119, %do, %dv_121, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1074) + %dv_123 = arith.addf %dv, %dv_122 : tensor<128x128xf32> loc(#loc1075) + %Di = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1076) + %Di_124 = arith.cmpi slt, %offs_m1, %Di : tensor<64xi32> loc(#loc1076) + %Di_125 = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1077) + %Di_126 = tt.addptr %Di_125, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1077) + %Di_127 = tt.load %Di_126, %Di_124 : tensor<64x!tt.ptr> loc(#loc1078) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1079) + %dpT_128 = arith.constant 0.000000e+00 : f32 loc(#loc1080) + %dpT_129 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1080) + %dpT_130 = tt.dot %v, %dpT, %dpT_129, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1080) + %dsT = tt.expand_dims %Di_127 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1081) + %dsT_131 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1082) + %dsT_132 = arith.subf %dpT_130, %dsT_131 : tensor<128x64xf32> loc(#loc1082) + %dsT_133 = arith.mulf %pT_118, %dsT_132 : tensor<128x64xf32> loc(#loc1083) + %grad_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1084) + %grad_scores_134 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1085) + %grad_scores_135 = arith.cmpi slt, %grad_scores, %grad_scores_134 : tensor<1x64xi32> loc(#loc1085) + %grad_scores_136 = arith.constant 0.000000e+00 : f32 loc(#loc1086) + %grad_scores_137 = arith.constant 0.000000e+00 : f32 loc(#loc1086) + %grad_scores_138 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1086) + %grad_scores_139 = tt.broadcast %grad_scores_135 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1086) + %grad_scores_140 = arith.select %grad_scores_139, %dsT_133, %grad_scores_138 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1086) + %dsT_141 = arith.constant 0.000000e+00 : f32 loc(#loc1087) + %dsT_142 = arith.constant 0.000000e+00 : f32 loc(#loc1087) + %dsT_143 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1087) + %dsT_144 = arith.select %tmp42, %grad_scores_140, %dsT_143 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1087) + %dk_145 = arith.truncf %dsT_144 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1088) + %dk_146 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1089) + %dk_147 = arith.constant 0.000000e+00 : f32 loc(#loc1090) + %dk_148 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1090) + %dk_149 = tt.dot %dk_145, %dk_146, %dk_148, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1090) + %dk_150 = arith.addf %dk, %dk_149 : tensor<128x128xf32> loc(#loc1091) + tt.return %dk_150, %dv_123 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc495) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc496) + %1 = ub.poison : tensor<128x128xf32> loc(#loc496) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc496) + } loc(#loc416) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%ptr: tensor<64x128x!tt.ptr> loc("ptr"(#loc242)), %offs_m: tensor<64xi32> loc("offs_m"(#loc242)), %offs_n: tensor<128xi32> loc("offs_n"(#loc242)), %M_LEN: i32 loc("M_LEN"(#loc242))) -> tensor<64x128xbf16> attributes {noinline = false} { + %0 = tt.expand_dims %offs_m {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc249) + %1 = tt.splat %M_LEN : i32 -> tensor<64x1xi32> loc(#loc250) + %2 = arith.cmpi slt, %0, %1 : tensor<64x1xi32> loc(#loc250) + %cst = arith.constant 0.000000e+00 : f32 loc(#loc251) + %3 = tt.broadcast %2 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc251) + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x128xf32> loc(#loc251) + %4 = arith.truncf %cst_0 : tensor<64x128xf32> to tensor<64x128xbf16> loc(#loc251) + %5 = tt.load %ptr, %3, %4 : tensor<64x128x!tt.ptr> loc(#loc251) + tt.return %5 : tensor<64x128xbf16> loc(#loc252) + ^bb1: // no predecessors + %6 = ub.poison : tensor<64x128xbf16> loc(#loc253) + tt.return %6 : tensor<64x128xbf16> loc(#loc253) + } loc(#loc242) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_inner__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_Pbf16_Pbf16_Pfp32_Pfp32_fp32S128_128S_fp32S128_128S_bf16S128_128S_bf16S128_128S_i32_i32_i32S128S_i32S64S_i32_i32_i32_i32_Pi32_i32__(37,)cconstexpr_bf16__(38,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc386)), %arg_K: !tt.ptr loc("arg_K"(#loc386)), %arg_V: !tt.ptr loc("arg_V"(#loc386)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc386)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc386)), %arg_DO: !tt.ptr loc("arg_DO"(#loc386)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc386)), %arg_DV: !tt.ptr loc("arg_DV"(#loc386)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc386)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc386)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc386)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc386)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc386)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc386)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc386)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc386)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc386)), %ks0: i32 loc("ks0"(#loc386)), %ks1: i32 loc("ks1"(#loc386)), %Q: !tt.ptr loc("Q"(#loc386)), %DO: !tt.ptr loc("DO"(#loc386)), %DELTA: !tt.ptr loc("DELTA"(#loc386)), %LSE: !tt.ptr loc("LSE"(#loc386)), %dk: tensor<128x128xf32> loc("dk"(#loc386)), %dv: tensor<128x128xf32> loc("dv"(#loc386)), %k: tensor<128x128xbf16> loc("k"(#loc386)), %v: tensor<128x128xbf16> loc("v"(#loc386)), %off_z: i32 loc("off_z"(#loc386)), %off_hq: i32 loc("off_hq"(#loc386)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc386)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc386)), %stride_qm: i32 loc("stride_qm"(#loc386)), %stride_qd: i32 loc("stride_qd"(#loc386)), %stride_dom: i32 loc("stride_dom"(#loc386)), %stride_dod: i32 loc("stride_dod"(#loc386)), %q_indices: !tt.ptr loc("q_indices"(#loc386)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc386))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc948) + %offs_v = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc949) + %qT_ptrs = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc950) + %qT_ptrs_0 = tt.splat %stride_qm : i32 -> tensor<1x64xi32> loc(#loc951) + %qT_ptrs_1 = arith.muli %qT_ptrs, %qT_ptrs_0 : tensor<1x64xi32> loc(#loc951) + %qT_ptrs_2 = tt.splat %Q : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc952) + %qT_ptrs_3 = tt.addptr %qT_ptrs_2, %qT_ptrs_1 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc952) + %qT_ptrs_4 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc953) + %qT_ptrs_5 = tt.splat %stride_qd : i32 -> tensor<128x1xi32> loc(#loc954) + %qT_ptrs_6 = arith.muli %qT_ptrs_4, %qT_ptrs_5 : tensor<128x1xi32> loc(#loc954) + %qT_ptrs_7 = tt.broadcast %qT_ptrs_3 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc955) + %qT_ptrs_8 = tt.broadcast %qT_ptrs_6 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc955) + %qT_ptrs_9 = tt.addptr %qT_ptrs_7, %qT_ptrs_8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc955) + %do_ptrs = tt.expand_dims %offs_m1 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc956) + %do_ptrs_10 = tt.splat %stride_dom : i32 -> tensor<64x1xi32> loc(#loc957) + %do_ptrs_11 = arith.muli %do_ptrs, %do_ptrs_10 : tensor<64x1xi32> loc(#loc957) + %do_ptrs_12 = tt.splat %DO : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc958) + %do_ptrs_13 = tt.addptr %do_ptrs_12, %do_ptrs_11 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc958) + %do_ptrs_14 = tt.expand_dims %offs_v {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc959) + %do_ptrs_15 = tt.splat %stride_dod : i32 -> tensor<1x128xi32> loc(#loc960) + %do_ptrs_16 = arith.muli %do_ptrs_14, %do_ptrs_15 : tensor<1x128xi32> loc(#loc960) + %do_ptrs_17 = tt.broadcast %do_ptrs_13 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc961) + %do_ptrs_18 = tt.broadcast %do_ptrs_16 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc961) + %do_ptrs_19 = tt.addptr %do_ptrs_17, %do_ptrs_18 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc961) + %hi = arith.constant 2 : i32 loc(#loc962) + %hi_20 = arith.constant 2 : i32 loc(#loc962) + %hi_21 = arith.muli %sparse_q_num_blocks, %hi_20 : i32 loc(#loc962) + %hi_22 = tt.call @"triton.language.standard.cdiv__i32__(1,)cconstexpr_64_"(%ks0) : (i32) -> i32 loc(#loc963) + %hi_23 = arith.constant 1 : i32 loc(#loc964) + %hi_24 = arith.maxsi %hi_22, %hi_23 : i32 loc(#loc964) + %hi_25 = arith.minsi %hi_21, %hi_24 : i32 loc(#loc965) + %c0_i32 = arith.constant 0 : i32 loc(#loc405) + %c1_i32 = arith.constant 1 : i32 loc(#loc405) + %0 = arith.bitcast %c0_i32 : i32 to i32 loc(#loc405) + %1 = arith.bitcast %hi_25 : i32 to i32 loc(#loc405) + %2 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc405) + %3 = ub.poison : i32 loc(#loc405) + %do_ptrs_26:5 = scf.for %start_m = %0 to %1 step %2 iter_args(%dk_27 = %dk, %dv_28 = %dv, %offs_m1_29 = %offs_m1, %qT_ptrs_30 = %qT_ptrs_9, %do_ptrs_31 = %do_ptrs_19) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %6:2 = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(41,)cconstexpr_bf16__(42,)cconstexpr_1_d_44269504__(43,)cconstexpr_True_"(%arg_Q, %arg_K, %arg_V, %arg_LSE, %arg_DELTA, %arg_DO, %arg_DQ, %arg_DV, %arg_KV_NUM_BLKS, %arg_KV_IDX, %arg_Q_NUM_BLKS, %arg_Q_IDX, %arg_FULL_KV_NUM_BLKS, %arg_FULL_KV_IDX, %arg_FULL_Q_NUM_BLKS, %arg_FULL_Q_IDX, %out_ptr0, %ks0, %ks1, %dk_27, %dv_28, %qT_ptrs_30, %k, %v, %do_ptrs_31, %DELTA, %LSE, %ks0, %ks1, %off_z, %off_hq, %offs_n1, %offs_m1_29, %offs_k, %offs_v, %stride_qm, %stride_qd, %stride_dom, %stride_dod, %q_indices, %sparse_q_num_blocks) : (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, i32, i32, tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x64x!tt.ptr>, tensor<128x128xbf16>, tensor<128x128xbf16>, tensor<64x128x!tt.ptr>, !tt.ptr, !tt.ptr, i32, i32, i32, i32, tensor<128xi32>, tensor<64xi32>, tensor<128xi32>, tensor<128xi32>, i32, i32, i32, i32, !tt.ptr, i32) -> (tensor<128x128xf32>, tensor<128x128xf32>) loc(#loc406) + %offset = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_offset_for_next_block__i32_Pi32_i32__(3,)cconstexpr_128__(4,)cconstexpr_2__(5,)cconstexpr_64__(6,)cconstexpr_False_"(%start_m, %q_indices, %sparse_q_num_blocks) : (i32, !tt.ptr, i32) -> i32 loc(#loc967) + %qT_ptrs_32 = arith.muli %offset, %stride_qm : i32 loc(#loc968) + %qT_ptrs_33 = tt.splat %qT_ptrs_32 : i32 -> tensor<128x64xi32> loc(#loc969) + %qT_ptrs_34 = tt.addptr %qT_ptrs_30, %qT_ptrs_33 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc969) + %do_ptrs_35 = arith.muli %offset, %stride_dom : i32 loc(#loc970) + %do_ptrs_36 = tt.splat %do_ptrs_35 : i32 -> tensor<64x128xi32> loc(#loc971) + %do_ptrs_37 = tt.addptr %do_ptrs_31, %do_ptrs_36 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc971) + %offs_m1_38 = tt.splat %offset : i32 -> tensor<64xi32> loc(#loc972) + %offs_m1_39 = arith.addi %offs_m1_29, %offs_m1_38 : tensor<64xi32> loc(#loc972) + scf.yield %6#0, %6#1, %offs_m1_39, %qT_ptrs_34, %do_ptrs_37 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc413) + } loc(#loc1103) + tt.return %do_ptrs_26#0, %do_ptrs_26#1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc414) + ^bb1: // no predecessors + %4 = ub.poison : tensor<128x128xf32> loc(#loc415) + %5 = ub.poison : tensor<128x128xf32> loc(#loc415) + tt.return %4, %5 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc415) + } loc(#loc386) + tt.func private @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.bwd_dkdv_block_mn__Pbf16_Pbf16_Pbf16_Pfp32_Pfp32_Pbf16_Pbf16_Pbf16_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pi32_Pbf16_i32_i32_fp32S128_128S_fp32S128_128S_Pbf16S128_64S_bf16S128_128S_bf16S128_128S_Pbf16S64_128S_Pfp32_Pfp32_i32_i32_i32_i32_i32S128S_i32S64S_i32S128S_i32S128S_i32_i32_i32_i32_Pi32_i32__(41,)cconstexpr_bf16__(42,)cconstexpr_1_d_44269504__(43,)cconstexpr_True_"(%arg_Q: !tt.ptr loc("arg_Q"(#loc416)), %arg_K: !tt.ptr loc("arg_K"(#loc416)), %arg_V: !tt.ptr loc("arg_V"(#loc416)), %arg_LSE: !tt.ptr loc("arg_LSE"(#loc416)), %arg_DELTA: !tt.ptr loc("arg_DELTA"(#loc416)), %arg_DO: !tt.ptr loc("arg_DO"(#loc416)), %arg_DQ: !tt.ptr loc("arg_DQ"(#loc416)), %arg_DV: !tt.ptr loc("arg_DV"(#loc416)), %arg_KV_NUM_BLKS: !tt.ptr loc("arg_KV_NUM_BLKS"(#loc416)), %arg_KV_IDX: !tt.ptr loc("arg_KV_IDX"(#loc416)), %arg_Q_NUM_BLKS: !tt.ptr loc("arg_Q_NUM_BLKS"(#loc416)), %arg_Q_IDX: !tt.ptr loc("arg_Q_IDX"(#loc416)), %arg_FULL_KV_NUM_BLKS: !tt.ptr loc("arg_FULL_KV_NUM_BLKS"(#loc416)), %arg_FULL_KV_IDX: !tt.ptr loc("arg_FULL_KV_IDX"(#loc416)), %arg_FULL_Q_NUM_BLKS: !tt.ptr loc("arg_FULL_Q_NUM_BLKS"(#loc416)), %arg_FULL_Q_IDX: !tt.ptr loc("arg_FULL_Q_IDX"(#loc416)), %out_ptr0: !tt.ptr loc("out_ptr0"(#loc416)), %ks0: i32 loc("ks0"(#loc416)), %ks1: i32 loc("ks1"(#loc416)), %dk: tensor<128x128xf32> loc("dk"(#loc416)), %dv: tensor<128x128xf32> loc("dv"(#loc416)), %qT_ptrs: tensor<128x64x!tt.ptr> loc("qT_ptrs"(#loc416)), %k: tensor<128x128xbf16> loc("k"(#loc416)), %v: tensor<128x128xbf16> loc("v"(#loc416)), %do_ptrs: tensor<64x128x!tt.ptr> loc("do_ptrs"(#loc416)), %DELTA: !tt.ptr loc("DELTA"(#loc416)), %LSE: !tt.ptr loc("LSE"(#loc416)), %Q_LEN: i32 loc("Q_LEN"(#loc416)), %KV_LEN: i32 loc("KV_LEN"(#loc416)), %off_z: i32 loc("off_z"(#loc416)), %off_hq: i32 loc("off_hq"(#loc416)), %offs_n1: tensor<128xi32> loc("offs_n1"(#loc416)), %offs_m1: tensor<64xi32> loc("offs_m1"(#loc416)), %offs_k: tensor<128xi32> loc("offs_k"(#loc416)), %offs_v: tensor<128xi32> loc("offs_v"(#loc416)), %stride_qm: i32 loc("stride_qm"(#loc416)), %stride_qd: i32 loc("stride_qd"(#loc416)), %stride_dom: i32 loc("stride_dom"(#loc416)), %stride_dod: i32 loc("stride_dod"(#loc416)), %q_indices: !tt.ptr loc("q_indices"(#loc416)), %sparse_q_num_blocks: i32 loc("sparse_q_num_blocks"(#loc416))) -> (tensor<128x128xf32>, tensor<128x128xf32>) attributes {noinline = false} { + %qT = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S128_64S_i32S128S_i32S64S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_True__(6,)cconstexpr_False__(7,)cconstexpr_128_"(%qT_ptrs, %offs_k, %offs_m1, %Q_LEN) : (tensor<128x64x!tt.ptr>, tensor<128xi32>, tensor<64xi32>, i32) -> tensor<128x64xbf16> loc(#loc1014) + %lse = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1015) + %lse_0 = arith.cmpi slt, %offs_m1, %lse : tensor<64xi32> loc(#loc1015) + %lse_1 = tt.splat %LSE : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1016) + %lse_2 = tt.addptr %lse_1, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1016) + %lse_3 = tt.load %lse_2, %lse_0 : tensor<64x!tt.ptr> loc(#loc1017) + %lse_4 = arith.constant 0xFF800000 : f32 loc(#loc1018) + %lse_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1018) + %lse_6 = arith.cmpf oeq, %lse_3, %lse_5 : tensor<64xf32> loc(#loc1018) + %lse_7 = arith.constant 0.000000e+00 : f32 loc(#loc1019) + %lse_8 = arith.constant 0.000000e+00 : f32 loc(#loc1019) + %lse_9 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1019) + %lse_10 = arith.select %lse_6, %lse_9, %lse_3 : tensor<64xi1>, tensor<64xf32> loc(#loc1019) + %qkT = arith.constant 0.000000e+00 : f32 loc(#loc1020) + %qkT_11 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1020) + %qkT_12 = tt.dot %k, %qT, %qkT_11, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1020) + %qkT_13 = arith.constant 0.0883883461 : f32 loc(#loc1021) + %qkT_14 = arith.constant 0.0883883461 : f32 loc(#loc1021) + %qkT_15 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1021) + %qkT_16 = arith.mulf %qkT_12, %qkT_15 : tensor<128x64xf32> loc(#loc1021) + %m = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1022) + %m_17 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S1_64S_i32__(%m, %Q_LEN) : (tensor<1x64xi32>, i32) -> tensor<1x64xi32> loc(#loc1023) + %n = tt.expand_dims %offs_n1 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc1024) + %n_18 = tt.call @torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.get_bounded_indices__i32S128_1S_i32__(%n, %KV_LEN) : (tensor<128x1xi32>, i32) -> tensor<128x1xi32> loc(#loc1025) + %post_mod_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1026) + %post_mod_scores_19 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1027) + %post_mod_scores_20 = arith.cmpi slt, %post_mod_scores, %post_mod_scores_19 : tensor<1x64xi32> loc(#loc1027) + %post_mod_scores_21 = arith.constant 0xFF800000 : f32 loc(#loc1028) + %post_mod_scores_22 = arith.constant 0xFF800000 : f32 loc(#loc1028) + %post_mod_scores_23 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1028) + %post_mod_scores_24 = tt.broadcast %post_mod_scores_20 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1028) + %post_mod_scores_25 = arith.select %post_mod_scores_24, %qkT_16, %post_mod_scores_23 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1028) + %post_mod_scores_26 = arith.constant 1.44269502 : f32 loc(#loc1068) + %post_mod_scores_27 = arith.constant 1.44269502 : f32 loc(#loc1068) + %post_mod_scores_28 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1068) + %post_mod_scores_29 = arith.mulf %post_mod_scores_25, %post_mod_scores_28 : tensor<128x64xf32> loc(#loc1068) + %pT = tt.expand_dims %lse_10 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1069) + %pT_30 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1070) + %pT_31 = arith.subf %post_mod_scores_29, %pT_30 : tensor<128x64xf32> loc(#loc1070) + %pT_32 = math.exp2 %pT_31 : tensor<128x64xf32> loc(#loc1071) + %do = tt.call @"torch._inductor.runtime.compile_tasks.cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.load_checked_2d__Pbf16S64_128S_i32S64S_i32S128S_i32__(3,)cconstexpr_None__(4,)cconstexpr_None__(5,)cconstexpr_False__(6,)cconstexpr_True__(8,)cconstexpr_128_"(%do_ptrs, %offs_m1, %offs_v, %Q_LEN) : (tensor<64x128x!tt.ptr>, tensor<64xi32>, tensor<128xi32>, i32) -> tensor<64x128xbf16> loc(#loc1072) + %dv_33 = arith.truncf %pT_32 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1073) + %dv_34 = arith.constant 0.000000e+00 : f32 loc(#loc1074) + %dv_35 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1074) + %dv_36 = tt.dot %dv_33, %do, %dv_35, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1074) + %dv_37 = arith.addf %dv, %dv_36 : tensor<128x128xf32> loc(#loc1075) + %Di = tt.splat %Q_LEN : i32 -> tensor<64xi32> loc(#loc1076) + %Di_38 = arith.cmpi slt, %offs_m1, %Di : tensor<64xi32> loc(#loc1076) + %Di_39 = tt.splat %DELTA : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc1077) + %Di_40 = tt.addptr %Di_39, %offs_m1 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc1077) + %Di_41 = tt.load %Di_40, %Di_38 : tensor<64x!tt.ptr> loc(#loc1078) + %dpT = tt.trans %do {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc1079) + %dpT_42 = arith.constant 0.000000e+00 : f32 loc(#loc1080) + %dpT_43 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1080) + %dpT_44 = tt.dot %v, %dpT, %dpT_43, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc1080) + %dsT = tt.expand_dims %Di_41 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc1081) + %dsT_45 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc1082) + %dsT_46 = arith.subf %dpT_44, %dsT_45 : tensor<128x64xf32> loc(#loc1082) + %dsT_47 = arith.mulf %pT_32, %dsT_46 : tensor<128x64xf32> loc(#loc1083) + %grad_scores = tt.expand_dims %offs_m1 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc1084) + %grad_scores_48 = tt.splat %Q_LEN : i32 -> tensor<1x64xi32> loc(#loc1085) + %grad_scores_49 = arith.cmpi slt, %grad_scores, %grad_scores_48 : tensor<1x64xi32> loc(#loc1085) + %grad_scores_50 = arith.constant 0.000000e+00 : f32 loc(#loc1086) + %grad_scores_51 = arith.constant 0.000000e+00 : f32 loc(#loc1086) + %grad_scores_52 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1086) + %grad_scores_53 = tt.broadcast %grad_scores_49 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc1086) + %grad_scores_54 = arith.select %grad_scores_53, %dsT_47, %grad_scores_52 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc1086) + %dk_55 = arith.truncf %grad_scores_54 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc1088) + %dk_56 = tt.trans %qT {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc1089) + %dk_57 = arith.constant 0.000000e+00 : f32 loc(#loc1090) + %dk_58 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1090) + %dk_59 = tt.dot %dk_55, %dk_56, %dk_58, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc1090) + %dk_60 = arith.addf %dk, %dk_59 : tensor<128x128xf32> loc(#loc1091) + tt.return %dk_60, %dv_37 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc495) + ^bb1: // no predecessors + %0 = ub.poison : tensor<128x128xf32> loc(#loc496) + %1 = ub.poison : tensor<128x128xf32> loc(#loc496) + tt.return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc496) + } loc(#loc416) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":94:54) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":94:49) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":95:54) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":95:49) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":96:54) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":96:49) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:74) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:66) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:100) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:91) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:82) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:59) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:126) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:118) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:152) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:143) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:134) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:111) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:53) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":99:58) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":99:53) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":100:58) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":100:53) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":102:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":103:9) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":104:10) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":106:10) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":111:24) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":112:36) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":113:34) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":115:27) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":116:28) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":117:23) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":119:15) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":120:16) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":122:28) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:25) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:47) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:35) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:59) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":125:25) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":125:47) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":125:35) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":125:59) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:27) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:50) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:37) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:61) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":131:9) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":132:9) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":133:10) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":135:14) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":136:26) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":137:26) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:14) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:7) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":140:24) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":142:29) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":143:30) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:29) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:54) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:44) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":145:35) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":146:41) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":147:31) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":148:26) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":149:26) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":151:35) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":152:42) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":152:54) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":154:55) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":154:78) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":155:50) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":155:83) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":155:68) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:30) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:52) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:40) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:63) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:32) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:55) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:42) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:66) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":160:32) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":160:55) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":160:42) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":160:66) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:30) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:35) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:46) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:56) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":163:17) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":164:19) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":167:19) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":168:21) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":169:25) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":172:22) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":174:36) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":175:42) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":175:29) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":178:107) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":179:111) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:58) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:34) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:25) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:57) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:33) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:26) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:30) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:50) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":191:18) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":195:30) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:27) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:41) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:53) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:39) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:42) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:29) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":207:12) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":214:39) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:31) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:45) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:62) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:43) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":218:46) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":218:33) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":226:16) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:32) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:43) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:24) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:63) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:74) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:56) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":232:14) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:48) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:59) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:76) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:87) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:69) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:30) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":239:29) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":240:30) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":242:26) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":244:30) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":245:25) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":246:25) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":249:22) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":250:22) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":252:25) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":253:42) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":253:29) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":256:107) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":257:107) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":262:30) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:32) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:51) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:34) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:56) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:44) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:67) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:36) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:59) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:46) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:70) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":268:36) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":268:59) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":268:46) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":268:70) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:34) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:39) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:50) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:60) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":271:21) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":272:23) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":275:25) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":276:29) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":278:39) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":279:46) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":279:58) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":281:58) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":281:80) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":282:53) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":282:81) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":282:70) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":286:32) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:30) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:43) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:55) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:42) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:45) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:32) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":298:16) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":306:41) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:34) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:47) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:64) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:46) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":310:49) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":310:36) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":318:20) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":303:12) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:31) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:42) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:23) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:62) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:73) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:55) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":325:26) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":326:25) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":327:25) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:50) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:71) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:61) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:30) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":334:14) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":337:29) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:31) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:27) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:45) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:53) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:41) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:64) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:71) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":344:59) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:59) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:55) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:74) +#loc228 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:69) +#loc229 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:29) +#loc230 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:99) +#loc231 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:4) +#loc233 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:16) +#loc234 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc235 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc236 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:11) +#loc237 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:4) +#loc238 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":118:0) +#loc239 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:31) +#loc240 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:11) +#loc241 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":127:4) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:27) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:38) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:20) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:56) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:67) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:49) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:41) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:52) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:23) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:15) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":792:4) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":387:26) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":388:26) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:26) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:37) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:18) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:56) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:67) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:49) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:26) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:37) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:18) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:56) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:67) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:49) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:43) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:90) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:101) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:63) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":397:28) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":405:12) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":411:64) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:28) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:19) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":415:28) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":415:19) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":417:19) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":417:8) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":419:11) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":419:4) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":458:105) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":459:19) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":461:14) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":464:36) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":464:46) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":467:36) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":467:46) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":476:43) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":476:54) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":476:79) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":480:31) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":481:22) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":483:23) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":484:22) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":485:23) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":486:22) +#loc301 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":487:22) +#loc302 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":488:24) +#loc303 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":489:23) +#loc304 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":490:23) +#loc305 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":491:33) +#loc306 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:34) +#loc307 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:49) +#loc308 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:41) +#loc309 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:70) +#loc310 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:79) +#loc311 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:91) +#loc312 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:99) +#loc313 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:111) +#loc314 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:102) +#loc315 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:128) +#loc316 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:119) +#loc317 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":493:23) +#loc318 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:34) +#loc319 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:49) +#loc320 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:41) +#loc321 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:70) +#loc322 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:79) +#loc323 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:91) +#loc324 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:99) +#loc325 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:111) +#loc326 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:102) +#loc327 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:128) +#loc328 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:119) +#loc329 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":495:25) +#loc330 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":496:24) +#loc331 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":497:23) +#loc332 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":498:23) +#loc333 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":503:69) +#loc334 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":506:27) +#loc335 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:39) +#loc336 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:21) +#loc337 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":510:104) +#loc338 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":512:20) +#loc339 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:22) +#loc340 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:19) +#loc341 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:14) +#loc342 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":520:39) +#loc343 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":520:50) +#loc344 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":520:71) +#loc345 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":524:32) +#loc346 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":524:43) +#loc347 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":524:62) +#loc348 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":524:73) +#loc349 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":524:54) +#loc350 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":531:43) +#loc351 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":533:15) +#loc352 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:30) +#loc353 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:21) +#loc354 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:10) +#loc355 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":537:11) +#loc356 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":537:4) +#loc357 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:41) +#loc358 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:52) +#loc359 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:23) +#loc360 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:15) +#loc362 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":762:21) +#loc363 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":762:11) +#loc364 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":762:4) +#loc366 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":752:33) +#loc367 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:38) +#loc368 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:24) +#loc369 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:109) +#loc370 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:113) +#loc371 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:39) +#loc372 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:55) +#loc373 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:25) +#loc374 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:30) +#loc375 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:35) +#loc376 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:60) +#loc377 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:34) +#loc378 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:48) +#loc379 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:63) +#loc380 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:29) +#loc381 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:47) +#loc382 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:61) +#loc383 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:42) +#loc384 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":758:11) +#loc385 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":758:4) +#loc387 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":580:26) +#loc388 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":581:26) +#loc389 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:26) +#loc390 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:37) +#loc391 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:18) +#loc392 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:56) +#loc393 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:67) +#loc394 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:49) +#loc395 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:27) +#loc396 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:38) +#loc397 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:19) +#loc398 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:58) +#loc399 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:69) +#loc400 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:51) +#loc401 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:42) +#loc402 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:87) +#loc403 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:98) +#loc404 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:61) +#loc405 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":592:28) +#loc406 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":600:12) +#loc407 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":605:62) +#loc408 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:28) +#loc409 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:19) +#loc410 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:28) +#loc411 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:19) +#loc412 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":610:19) +#loc413 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":610:8) +#loc414 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":612:11) +#loc415 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":612:4) +#loc417 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":651:105) +#loc418 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:52) +#loc419 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:28) +#loc420 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:22) +#loc421 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:26) +#loc422 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:46) +#loc423 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":658:20) +#loc424 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":660:15) +#loc425 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":662:36) +#loc426 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":662:46) +#loc427 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":665:36) +#loc428 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":665:46) +#loc429 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":674:43) +#loc430 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":674:54) +#loc431 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":674:78) +#loc432 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":678:32) +#loc433 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":679:24) +#loc434 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":681:25) +#loc435 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":682:24) +#loc436 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":683:25) +#loc437 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":684:24) +#loc438 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":685:24) +#loc439 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":686:25) +#loc440 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":687:24) +#loc441 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":688:24) +#loc442 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":689:33) +#loc443 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:34) +#loc444 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:49) +#loc445 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:41) +#loc446 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:70) +#loc447 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:79) +#loc448 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:91) +#loc449 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:99) +#loc450 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:111) +#loc451 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:102) +#loc452 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:128) +#loc453 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:119) +#loc454 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":691:24) +#loc455 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:34) +#loc456 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:49) +#loc457 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:41) +#loc458 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:70) +#loc459 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:79) +#loc460 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:91) +#loc461 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:99) +#loc462 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:111) +#loc463 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:102) +#loc464 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:128) +#loc465 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:119) +#loc466 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":693:25) +#loc467 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":694:24) +#loc468 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":695:24) +#loc469 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":696:24) +#loc470 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":700:69) +#loc471 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":703:27) +#loc472 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:44) +#loc473 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:40) +#loc474 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:22) +#loc475 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":705:99) +#loc476 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:24) +#loc477 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:43) +#loc478 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:10) +#loc479 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:53) +#loc480 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:29) +#loc481 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:21) +#loc482 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:29) +#loc483 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:20) +#loc484 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:25) +#loc485 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:22) +#loc486 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:16) +#loc487 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":723:39) +#loc488 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":723:50) +#loc489 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":723:70) +#loc490 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":737:45) +#loc491 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:24) +#loc492 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:52) +#loc493 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:43) +#loc494 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:10) +#loc495 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":741:11) +#loc496 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":741:4) +#loc516 = loc("ZQ"(#loc24)) +#loc517 = loc("HQ"(#loc25)) +#loc518 = loc("HKV"(#loc26)) +#loc519 = loc("ZKV"(#loc27)) +#loc520 = loc("pid"(#loc28)) +#loc521 = loc("NUM_KV_BLOCKS"(#loc29)) +#loc522 = loc("NUM_Q_BLOCKS"(#loc30)) +#loc523 = loc("off_zq"(#loc31)) +#loc524 = loc("off_hkv"(#loc32)) +#loc525 = loc("off_zkv"(#loc33)) +#loc526 = loc("SPARSE_Z"(#loc34)) +#loc527 = loc("SPARSE_HQ"(#loc35)) +#loc528 = loc("sparse_idx_z"(#loc36)) +#loc529 = loc("k_adj"(#loc37)) +#loc530 = loc("k_adj"(#loc38)) +#loc531 = loc("k_adj"(#loc39)) +#loc532 = loc("k_adj"(#loc40)) +#loc533 = loc("v_adj"(#loc41)) +#loc534 = loc("v_adj"(#loc42)) +#loc535 = loc("v_adj"(#loc43)) +#loc536 = loc("v_adj"(#loc44)) +#loc537 = loc("dv_adj"(#loc45)) +#loc538 = loc("dv_adj"(#loc46)) +#loc539 = loc("dv_adj"(#loc47)) +#loc540 = loc("dv_adj"(#loc48)) +#loc541 = loc("K"(#loc49)) +#loc542 = loc("V"(#loc50)) +#loc543 = loc("DV"(#loc51)) +#loc544 = loc("RCP_LN2"(#loc52)) +#loc545 = loc("offs_k"(#loc53)) +#loc546 = loc("offs_v"(#loc54)) +#loc547 = loc("off_pid"(#loc57)) +#loc548 = loc("SPARSE_Q_MULTIPLE"(#loc58)) +#loc549 = loc("SPARSE_KV_MULTIPLE"(#loc59)) +#loc550 = loc("off_hq2"(#loc60)) +#loc551 = loc("off_hq2"(#loc61)) +#loc552 = loc("off_hq2"(#loc62)) +#loc553 = loc("start_m2_block"(#loc63)) +#loc554 = loc("off_pid_mask"(#loc64)) +#loc555 = loc("stride_kv_num_blks_h"(#loc65)) +#loc556 = loc("stride_kv_idx_h"(#loc66)) +#loc557 = loc("stride_kv_idx_m"(#loc67)) +#loc558 = loc("sparse_idx_hq2"(#loc68)) +#loc559 = loc("sparse_hz_offset"(#loc69)) +#loc560 = loc("sparse_hz_offset"(#loc70)) +#loc561 = loc("sparse_kv_num_blks_offset"(#loc71)) +#loc562 = loc("sparse_kv_num_blks_offset"(#loc72)) +#loc563 = loc("sparse_kv_idx_offset"(#loc73)) +#loc564 = loc("sparse_kv_idx_offset"(#loc74)) +#loc565 = loc("sparse_kv_idx_offset"(#loc75)) +#loc566 = loc("q_adj2"(#loc76)) +#loc567 = loc("q_adj2"(#loc77)) +#loc568 = loc("q_adj2"(#loc78)) +#loc569 = loc("q_adj2"(#loc79)) +#loc570 = loc("do_adj2"(#loc80)) +#loc571 = loc("do_adj2"(#loc81)) +#loc572 = loc("do_adj2"(#loc82)) +#loc573 = loc("do_adj2"(#loc83)) +#loc574 = loc("dq_adj2"(#loc84)) +#loc575 = loc("dq_adj2"(#loc85)) +#loc576 = loc("dq_adj2"(#loc86)) +#loc577 = loc("dq_adj2"(#loc87)) +#loc578 = loc("off_chz2"(#loc88)) +#loc579 = loc("off_chz2"(#loc89)) +#loc580 = loc("off_chz2"(#loc90)) +#loc581 = loc("off_chz2"(#loc91)) +#loc582 = loc("Q2"(#loc92)) +#loc583 = loc("DO2"(#loc93)) +#loc584 = loc("DQ2"(#loc94)) +#loc585 = loc("LSE2"(#loc95)) +#loc586 = loc("DELTA2"(#loc96)) +#loc587 = loc("dq"(#loc97)) +#loc588 = loc("start_m2"(#loc98)) +#loc589 = loc("offs_m2"(#loc99)) +#loc590 = loc("offs_m2"(#loc100)) +#loc591 = loc("q"(#loc101)) +#loc592 = loc("do"(#loc102)) +#loc593 = loc("Di"(#loc103)) +#loc594 = loc("Di"(#loc104)) +#loc595 = loc("Di"(#loc105)) +#loc596 = loc("lse"(#loc106)) +#loc597 = loc("lse"(#loc107)) +#loc598 = loc("lse"(#loc108)) +#loc599 = loc("lse"(#loc109)) +#loc600 = loc("lse"(#loc110)) +#loc601 = loc("lse"(#loc111)) +#loc602 = loc("kv_indices"(#loc112)) +#loc603 = loc("kv_start"(#loc113)) +#loc604 = loc("kv_start"(#loc114)) +#loc605 = loc("sparse_kv_num_blocks"(#loc115)) +#loc606 = loc("sparse_kv_num_blocks"(#loc116)) +#loc607 = loc("offs_n2"(#loc117)) +#loc608 = loc("offs_n2"(#loc118)) +#loc609 = loc("dq"(#loc119)) +#loc610 = loc("kv_indices"(#loc120)) +#loc611 = loc("kv_start"(#loc121)) +#loc612 = loc("kv_start"(#loc122)) +#loc613 = loc("sparse_kv_num_blocks"(#loc123)) +#loc614 = loc("sparse_kv_num_blocks"(#loc124)) +#loc615 = loc("offs_n2"(#loc125)) +#loc616 = loc("offs_n2"(#loc126)) +#loc617 = loc("dq"(#loc127)) +#loc618 = loc("dq_ptrs"(#loc128)) +#loc619 = loc("dq_ptrs"(#loc129)) +#loc620 = loc("dq_ptrs"(#loc130)) +#loc621 = loc("dq_ptrs"(#loc131)) +#loc622 = loc("dq_ptrs"(#loc132)) +#loc623 = loc("dq_ptrs"(#loc133)) +#loc624 = loc("dq"(#loc134)) +#loc625 = loc("SPARSE_Q_MULTIPLE"(#loc141)) +#loc626 = loc("SPARSE_KV_MULTIPLE"(#loc142)) +#loc627 = loc("pid_mask"(#loc143)) +#loc628 = loc("stride_q_num_blks_h"(#loc144)) +#loc629 = loc("stride_q_idx_h"(#loc145)) +#loc630 = loc("stride_q_idx_n"(#loc146)) +#loc631 = loc("dv"(#loc147)) +#loc632 = loc("dk"(#loc148)) +#loc633 = loc("start_n1"(#loc149)) +#loc634 = loc("offs_n1"(#loc150)) +#loc635 = loc("offs_n1"(#loc151)) +#loc636 = loc("k"(#loc152)) +#loc637 = loc("v"(#loc153)) +#loc638 = loc("dv"(#loc154)) +#loc639 = loc("off_hq1"(#loc155)) +#loc640 = loc("off_hq1"(#loc156)) +#loc641 = loc("q_adj1"(#loc157)) +#loc642 = loc("q_adj1"(#loc158)) +#loc643 = loc("q_adj1"(#loc159)) +#loc644 = loc("q_adj1"(#loc160)) +#loc645 = loc("do_adj1"(#loc161)) +#loc646 = loc("do_adj1"(#loc162)) +#loc647 = loc("do_adj1"(#loc163)) +#loc648 = loc("do_adj1"(#loc164)) +#loc649 = loc("dq_adj1"(#loc165)) +#loc650 = loc("dq_adj1"(#loc166)) +#loc651 = loc("dq_adj1"(#loc167)) +#loc652 = loc("dq_adj1"(#loc168)) +#loc653 = loc("off_chz1"(#loc169)) +#loc654 = loc("off_chz1"(#loc170)) +#loc655 = loc("off_chz1"(#loc171)) +#loc656 = loc("off_chz1"(#loc172)) +#loc657 = loc("Q1"(#loc173)) +#loc658 = loc("DO1"(#loc174)) +#loc659 = loc("LSE1"(#loc175)) +#loc660 = loc("DELTA1"(#loc176)) +#loc661 = loc("sparse_idx_hq1"(#loc177)) +#loc662 = loc("sparse_hz_offset"(#loc178)) +#loc663 = loc("sparse_hz_offset"(#loc179)) +#loc664 = loc("sparse_q_num_blks_offset"(#loc180)) +#loc665 = loc("sparse_q_num_blks_offset"(#loc181)) +#loc666 = loc("sparse_q_idx_offset"(#loc182)) +#loc667 = loc("sparse_q_idx_offset"(#loc183)) +#loc668 = loc("sparse_q_idx_offset"(#loc184)) +#loc669 = loc("q_indices"(#loc185)) +#loc670 = loc("q_start"(#loc186)) +#loc671 = loc("q_start"(#loc187)) +#loc672 = loc("sparse_q_num_blocks"(#loc188)) +#loc673 = loc("sparse_q_num_blocks"(#loc189)) +#loc674 = loc("offs_m1"(#loc190)) +#loc675 = loc("offs_m1"(#loc191)) +#loc676 = loc("q_indices"(#loc193)) +#loc677 = loc("q_start"(#loc194)) +#loc678 = loc("q_start"(#loc195)) +#loc679 = loc("sparse_q_num_blocks"(#loc196)) +#loc680 = loc("sparse_q_num_blocks"(#loc197)) +#loc681 = loc("offs_m1"(#loc198)) +#loc682 = loc("offs_m1"(#loc199)) +#loc683 = loc("dv_ptrs"(#loc202)) +#loc684 = loc("dv_ptrs"(#loc203)) +#loc685 = loc("dv_ptrs"(#loc204)) +#loc686 = loc("dv_ptrs"(#loc205)) +#loc687 = loc("dv_ptrs"(#loc206)) +#loc688 = loc("dv_ptrs"(#loc207)) +#loc689 = loc("index_n"(#loc208)) +#loc690 = loc("index_k"(#loc209)) +#loc691 = loc("index_v"(#loc210)) +#loc692 = loc("dk"(#loc215)) +#loc693 = loc("mask"(#loc216)) +#loc694 = loc("xindex"(#loc217)) +#loc695 = loc("xindex"(#loc218)) +#loc696 = loc("xindex"(#loc219)) +#loc697 = loc("xindex"(#loc220)) +#loc698 = loc("xindex"(#loc221)) +#loc699 = loc("xindex"(#loc222)) +#loc700 = loc("xindex"(#loc223)) +#loc701 = loc("xindex"(#loc224)) +#loc709 = loc("ptr"(#loc243)) +#loc710 = loc("ptr"(#loc244)) +#loc711 = loc("ptr"(#loc245)) +#loc712 = loc("ptr"(#loc246)) +#loc713 = loc("ptr"(#loc247)) +#loc714 = loc("ptr"(#loc248)) +#loc751 = loc("offs_k"(#loc255)) +#loc752 = loc("offs_v"(#loc256)) +#loc753 = loc("kT_ptrs"(#loc257)) +#loc754 = loc("kT_ptrs"(#loc258)) +#loc755 = loc("kT_ptrs"(#loc259)) +#loc756 = loc("kT_ptrs"(#loc260)) +#loc757 = loc("kT_ptrs"(#loc261)) +#loc758 = loc("kT_ptrs"(#loc262)) +#loc759 = loc("vT_ptrs"(#loc263)) +#loc760 = loc("vT_ptrs"(#loc264)) +#loc761 = loc("vT_ptrs"(#loc265)) +#loc762 = loc("vT_ptrs"(#loc266)) +#loc763 = loc("vT_ptrs"(#loc267)) +#loc764 = loc("vT_ptrs"(#loc268)) +#loc765 = loc("hi"(#loc269)) +#loc766 = loc("hi"(#loc270)) +#loc767 = loc("hi"(#loc271)) +#loc768 = loc("hi"(#loc272)) +#loc769 = loc("dq"(#loc273)) +#loc770 = loc("dq"(#loc274)) +#loc771 = loc("offset"(#loc275)) +#loc772 = loc("kT_ptrs"(#loc276)) +#loc773 = loc("kT_ptrs"(#loc277)) +#loc774 = loc("vT_ptrs"(#loc278)) +#loc775 = loc("vT_ptrs"(#loc279)) +#loc776 = loc("offs_n2"(#loc280)) +#loc817 = loc("kT"(#loc285)) +#loc818 = loc("qk"(#loc286)) +#loc819 = loc("qk"(#loc287)) +#loc820 = loc("n"(#loc288)) +#loc821 = loc("n"(#loc289)) +#loc822 = loc("m"(#loc290)) +#loc823 = loc("m"(#loc291)) +#loc824 = loc("post_mod_scores"(#loc292)) +#loc825 = loc("post_mod_scores"(#loc293)) +#loc826 = loc("post_mod_scores"(#loc294)) +#loc827 = loc("tmp2"(#loc295)) +#loc828 = loc("tmp3"(#loc296)) +#loc829 = loc("tmp5"(#loc297)) +#loc830 = loc("tmp6"(#loc298)) +#loc831 = loc("tmp7"(#loc299)) +#loc832 = loc("tmp8"(#loc300)) +#loc833 = loc("tmp9"(#loc301)) +#loc834 = loc("tmp10"(#loc302)) +#loc835 = loc("tmp11"(#loc303)) +#loc836 = loc("tmp12"(#loc304)) +#loc837 = loc("tmp13"(#loc305)) +#loc838 = loc("tmp14"(#loc306)) +#loc839 = loc("tmp14"(#loc307)) +#loc840 = loc("tmp14"(#loc308)) +#loc841 = loc("tmp14"(#loc309)) +#loc842 = loc("tmp14"(#loc310)) +#loc843 = loc("tmp14"(#loc311)) +#loc844 = loc("tmp14"(#loc312)) +#loc845 = loc("tmp14"(#loc313)) +#loc846 = loc("tmp14"(#loc314)) +#loc847 = loc("tmp14"(#loc315)) +#loc848 = loc("tmp14"(#loc316)) +#loc849 = loc("tmp15"(#loc317)) +#loc850 = loc("tmp16"(#loc318)) +#loc851 = loc("tmp16"(#loc319)) +#loc852 = loc("tmp16"(#loc320)) +#loc853 = loc("tmp16"(#loc321)) +#loc854 = loc("tmp16"(#loc322)) +#loc855 = loc("tmp16"(#loc323)) +#loc856 = loc("tmp16"(#loc324)) +#loc857 = loc("tmp16"(#loc325)) +#loc858 = loc("tmp16"(#loc326)) +#loc859 = loc("tmp16"(#loc327)) +#loc860 = loc("tmp16"(#loc328)) +#loc861 = loc("tmp17"(#loc329)) +#loc862 = loc("tmp18"(#loc330)) +#loc863 = loc("tmp19"(#loc331)) +#loc864 = loc("tmp20"(#loc332)) +#loc865 = loc("post_mod_scores"(#loc333)) +#loc866 = loc("post_mod_scores"(#loc334)) +#loc867 = loc("p"(#loc335)) +#loc868 = loc("p"(#loc336)) +#loc869 = loc("vT"(#loc337)) +#loc870 = loc("dp"(#loc338)) +#loc871 = loc("ds"(#loc339)) +#loc872 = loc("ds"(#loc340)) +#loc873 = loc("ds"(#loc341)) +#loc874 = loc("grad_scores"(#loc342)) +#loc875 = loc("grad_scores"(#loc343)) +#loc876 = loc("grad_scores"(#loc344)) +#loc877 = loc("scatter_mask"(#loc345)) +#loc878 = loc("scatter_mask"(#loc346)) +#loc879 = loc("scatter_mask"(#loc347)) +#loc880 = loc("scatter_mask"(#loc348)) +#loc881 = loc("scatter_mask"(#loc349)) +#loc882 = loc("ds"(#loc350)) +#loc883 = loc("ds"(#loc351)) +#loc884 = loc("dq"(#loc352)) +#loc885 = loc("dq"(#loc353)) +#loc886 = loc("dq"(#loc354)) +#loc893 = loc("cur_block_idx"(#loc366)) +#loc894 = loc("cur_block"(#loc367)) +#loc895 = loc("cur_block"(#loc368)) +#loc896 = loc("next_block"(#loc369)) +#loc897 = loc("next_block"(#loc370)) +#loc898 = loc("next_block"(#loc371)) +#loc899 = loc("next_block"(#loc372)) +#loc900 = loc("next_block"(#loc373)) +#loc901 = loc("needs_jump"(#loc374)) +#loc902 = loc("needs_jump"(#loc375)) +#loc903 = loc("needs_jump"(#loc376)) +#loc904 = loc("jump_to_block"(#loc377)) +#loc905 = loc("jump_to_block"(#loc378)) +#loc906 = loc("jump_to_block"(#loc379)) +#loc907 = loc("offset"(#loc380)) +#loc908 = loc("offset"(#loc381)) +#loc909 = loc("offset"(#loc382)) +#loc910 = loc("offset"(#loc383)) +#loc948 = loc("offs_k"(#loc387)) +#loc949 = loc("offs_v"(#loc388)) +#loc950 = loc("qT_ptrs"(#loc389)) +#loc951 = loc("qT_ptrs"(#loc390)) +#loc952 = loc("qT_ptrs"(#loc391)) +#loc953 = loc("qT_ptrs"(#loc392)) +#loc954 = loc("qT_ptrs"(#loc393)) +#loc955 = loc("qT_ptrs"(#loc394)) +#loc956 = loc("do_ptrs"(#loc395)) +#loc957 = loc("do_ptrs"(#loc396)) +#loc958 = loc("do_ptrs"(#loc397)) +#loc959 = loc("do_ptrs"(#loc398)) +#loc960 = loc("do_ptrs"(#loc399)) +#loc961 = loc("do_ptrs"(#loc400)) +#loc962 = loc("hi"(#loc401)) +#loc963 = loc("hi"(#loc402)) +#loc964 = loc("hi"(#loc403)) +#loc965 = loc("hi"(#loc404)) +#loc966 = loc("dk"(#loc405)) +#loc967 = loc("offset"(#loc407)) +#loc968 = loc("qT_ptrs"(#loc408)) +#loc969 = loc("qT_ptrs"(#loc409)) +#loc970 = loc("do_ptrs"(#loc410)) +#loc971 = loc("do_ptrs"(#loc411)) +#loc972 = loc("offs_m1"(#loc412)) +#loc1014 = loc("qT"(#loc417)) +#loc1015 = loc("lse"(#loc418)) +#loc1016 = loc("lse"(#loc419)) +#loc1017 = loc("lse"(#loc420)) +#loc1018 = loc("lse"(#loc421)) +#loc1019 = loc("lse"(#loc422)) +#loc1020 = loc("qkT"(#loc423)) +#loc1021 = loc("qkT"(#loc424)) +#loc1022 = loc("m"(#loc425)) +#loc1023 = loc("m"(#loc426)) +#loc1024 = loc("n"(#loc427)) +#loc1025 = loc("n"(#loc428)) +#loc1026 = loc("post_mod_scores"(#loc429)) +#loc1027 = loc("post_mod_scores"(#loc430)) +#loc1028 = loc("post_mod_scores"(#loc431)) +#loc1029 = loc("tmp24"(#loc432)) +#loc1030 = loc("tmp25"(#loc433)) +#loc1031 = loc("tmp27"(#loc434)) +#loc1032 = loc("tmp28"(#loc435)) +#loc1033 = loc("tmp29"(#loc436)) +#loc1034 = loc("tmp30"(#loc437)) +#loc1035 = loc("tmp31"(#loc438)) +#loc1036 = loc("tmp32"(#loc439)) +#loc1037 = loc("tmp33"(#loc440)) +#loc1038 = loc("tmp34"(#loc441)) +#loc1039 = loc("tmp35"(#loc442)) +#loc1040 = loc("tmp36"(#loc443)) +#loc1041 = loc("tmp36"(#loc444)) +#loc1042 = loc("tmp36"(#loc445)) +#loc1043 = loc("tmp36"(#loc446)) +#loc1044 = loc("tmp36"(#loc447)) +#loc1045 = loc("tmp36"(#loc448)) +#loc1046 = loc("tmp36"(#loc449)) +#loc1047 = loc("tmp36"(#loc450)) +#loc1048 = loc("tmp36"(#loc451)) +#loc1049 = loc("tmp36"(#loc452)) +#loc1050 = loc("tmp36"(#loc453)) +#loc1051 = loc("tmp37"(#loc454)) +#loc1052 = loc("tmp38"(#loc455)) +#loc1053 = loc("tmp38"(#loc456)) +#loc1054 = loc("tmp38"(#loc457)) +#loc1055 = loc("tmp38"(#loc458)) +#loc1056 = loc("tmp38"(#loc459)) +#loc1057 = loc("tmp38"(#loc460)) +#loc1058 = loc("tmp38"(#loc461)) +#loc1059 = loc("tmp38"(#loc462)) +#loc1060 = loc("tmp38"(#loc463)) +#loc1061 = loc("tmp38"(#loc464)) +#loc1062 = loc("tmp38"(#loc465)) +#loc1063 = loc("tmp39"(#loc466)) +#loc1064 = loc("tmp40"(#loc467)) +#loc1065 = loc("tmp41"(#loc468)) +#loc1066 = loc("tmp42"(#loc469)) +#loc1067 = loc("post_mod_scores"(#loc470)) +#loc1068 = loc("post_mod_scores"(#loc471)) +#loc1069 = loc("pT"(#loc472)) +#loc1070 = loc("pT"(#loc473)) +#loc1071 = loc("pT"(#loc474)) +#loc1072 = loc("do"(#loc475)) +#loc1073 = loc("dv"(#loc476)) +#loc1074 = loc("dv"(#loc477)) +#loc1075 = loc("dv"(#loc478)) +#loc1076 = loc("Di"(#loc479)) +#loc1077 = loc("Di"(#loc480)) +#loc1078 = loc("Di"(#loc481)) +#loc1079 = loc("dpT"(#loc482)) +#loc1080 = loc("dpT"(#loc483)) +#loc1081 = loc("dsT"(#loc484)) +#loc1082 = loc("dsT"(#loc485)) +#loc1083 = loc("dsT"(#loc486)) +#loc1084 = loc("grad_scores"(#loc487)) +#loc1085 = loc("grad_scores"(#loc488)) +#loc1086 = loc("grad_scores"(#loc489)) +#loc1087 = loc("dsT"(#loc490)) +#loc1088 = loc("dk"(#loc491)) +#loc1089 = loc("dk"(#loc492)) +#loc1090 = loc("dk"(#loc493)) +#loc1091 = loc("dk"(#loc494)) +#loc1092 = loc("SPARSE_Q_MULTIPLE"(#loc548)) +#loc1093 = loc("SPARSE_KV_MULTIPLE"(#loc549)) +#loc1094 = loc("SPARSE_Q_MULTIPLE"(#loc625)) +#loc1095 = loc("SPARSE_KV_MULTIPLE"(#loc626)) +#loc1096 = loc("dk"(#loc638)) +#loc1097 = loc("offs_n2"(#loc769)) +#loc1098 = loc("dv"(#loc966)) +#loc1099 = loc("kT_ptrs"(#loc1097)) +#loc1100 = loc("offs_m1"(#loc1098)) +#loc1101 = loc("vT_ptrs"(#loc1099)) +#loc1102 = loc("qT_ptrs"(#loc1100)) +#loc1103 = loc("do_ptrs"(#loc1102)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttgir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..9d86274e84ab0a9344c0fe48f19576841e3289d8 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttgir @@ -0,0 +1,1932 @@ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":18:0) +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +#loc301 = loc("arg_Q"(#loc)) +#loc302 = loc("arg_K"(#loc)) +#loc303 = loc("arg_V"(#loc)) +#loc304 = loc("arg_LSE"(#loc)) +#loc305 = loc("arg_DELTA"(#loc)) +#loc306 = loc("arg_DO"(#loc)) +#loc307 = loc("arg_DQ"(#loc)) +#loc308 = loc("arg_DV"(#loc)) +#loc309 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc310 = loc("arg_KV_IDX"(#loc)) +#loc311 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc312 = loc("arg_Q_IDX"(#loc)) +#loc313 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc314 = loc("arg_FULL_KV_IDX"(#loc)) +#loc315 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc316 = loc("arg_FULL_Q_IDX"(#loc)) +#loc317 = loc("out_ptr0"(#loc)) +#loc318 = loc("ks0"(#loc)) +#loc319 = loc("ks1"(#loc)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<1024> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_0 = arith.constant dense<128> : tensor<1x128xi32, #blocked> loc(#loc1) + %cst_1 = arith.constant dense<4096> : tensor<128x1xi32, #blocked> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %c32_i32 = arith.constant 32 : i32 loc(#loc1) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xbf16, #blocked> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16, #blocked1> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x128xbf16, #blocked> loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_5 = arith.constant dense<0.0883883461> : tensor<128x128xf32, #mma> loc(#loc1) + %cst_6 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> loc(#loc1) + %cst_7 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_8 = arith.constant dense<0.0883883461> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_9 = arith.constant dense<0xFF800000> : tensor<128x64xf32, #mma1> loc(#loc1) + %cst_10 = arith.constant dense<1.44269502> : tensor<128x64xf32, #mma1> loc(#loc1) + %true = arith.constant true loc(#loc1) + %c-1_i32 = arith.constant -1 : i32 loc(#loc1) + %c3_i32 = arith.constant 3 : i32 loc(#loc1) + %cst_11 = arith.constant dense<65536> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_12 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc1) + %cst_13 = arith.constant dense<262144> : tensor<128x64xi32, #blocked1> loc(#loc1) + %cst_14 = arith.constant dense<8192> : tensor<64x128xi32, #blocked> loc(#loc1) + %cst_15 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_16 = arith.constant dense<64> : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc1) + %cst_17 = arith.constant dense<128> : tensor<64x1xi32, #blocked> loc(#loc1) + %cst_18 = arith.constant dense<4096> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_19 = arith.constant dense<1024> : tensor<1x64xi32, #blocked1> loc(#loc1) + %cst_20 = arith.constant dense<128> : tensor<128x1xi32, #blocked> loc(#loc1) + %cst_21 = arith.constant dense<16> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_22 = arith.constant dense<16> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_23 = arith.constant dense<0xFF800000> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_24 = arith.constant dense<0.000000e+00> : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc1) + %cst_25 = arith.constant dense<1> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_26 = arith.constant dense<1> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_27 = arith.constant dense<0> : tensor<128x1xi32, #mma1> loc(#loc1) + %cst_28 = arith.constant dense<0> : tensor<1x64xi32, #mma1> loc(#loc1) + %cst_29 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc1) + %cst_30 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc1) + %0 = arith.muli %ks0, %c4096_i32 : i32 loc(#loc2) + %1 = arith.cmpi sle, %ks0, %c1_i32 : i32 loc(#loc3) + %2 = arith.extui %1 : i1 to i32 loc(#loc4) + %3 = arith.cmpi sgt, %ks0, %c1_i32 : i32 loc(#loc5) + %4 = arith.extui %3 : i1 to i32 loc(#loc6) + %5 = arith.muli %ks0, %4 : i32 loc(#loc6) + %6 = arith.addi %2, %5 : i32 loc(#loc7) + %7 = arith.muli %6, %c4096_i32 : i32 loc(#loc8) + %8 = arith.muli %6, %c128_i32 : i32 loc(#loc9) + %9 = arith.muli %ks1, %c1024_i32 : i32 loc(#loc10) + %pid = tt.get_program_id x : i32 loc(#loc320) + %NUM_KV_BLOCKS = arith.addi %ks1, %c127_i32 : i32 loc(#loc586) + %NUM_KV_BLOCKS_31 = arith.divsi %NUM_KV_BLOCKS, %c128_i32 : i32 loc(#loc587) + %NUM_Q_BLOCKS = arith.addi %ks0, %c127_i32 : i32 loc(#loc588) + %NUM_Q_BLOCKS_32 = arith.divsi %NUM_Q_BLOCKS, %c128_i32 : i32 loc(#loc589) + %off_zq = tt.get_program_id y : i32 loc(#loc323) + %off_hkv = tt.get_program_id z : i32 loc(#loc324) + %k_adj = arith.muli %off_hkv, %c128_i32 : i32 loc(#loc325) + %k_adj_33 = arith.extsi %k_adj : i32 to i64 loc(#loc326) + %dv_adj = arith.muli %9, %off_zq : i32 loc(#loc327) + %dv_adj_34 = arith.addi %k_adj, %dv_adj : i32 loc(#loc328) + %dv_adj_35 = arith.extsi %dv_adj_34 : i32 to i64 loc(#loc329) + %K = tt.addptr %arg_K, %k_adj_33 : !tt.ptr, i64 loc(#loc330) + %V = tt.addptr %arg_V, %k_adj_33 : !tt.ptr, i64 loc(#loc331) + %DV = tt.addptr %arg_DV, %dv_adj_35 : !tt.ptr, i64 loc(#loc332) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc333) + %offs_k_36 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc333) + %10 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS_31 : i32 loc(#loc27) + scf.if %10 { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS_31 : i32 loc(#loc334) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS_32 : i32 loc(#loc335) + %off_hq2_37 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc336) + %off_hq2_38 = arith.addi %off_hq2, %off_hq2_37 : i32 loc(#loc337) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS_32 : i32 loc(#loc338) + %q_adj2 = arith.muli %off_hq2_38, %c128_i32 : i32 loc(#loc339) + %q_adj2_39 = arith.muli %0, %off_zq : i32 loc(#loc340) + %q_adj2_40 = arith.addi %q_adj2, %q_adj2_39 : i32 loc(#loc341) + %q_adj2_41 = arith.extsi %q_adj2_40 : i32 to i64 loc(#loc342) + %do_adj2 = arith.muli %8, %off_hq2_38 : i32 loc(#loc343) + %do_adj2_42 = arith.muli %7, %off_zq : i32 loc(#loc344) + %do_adj2_43 = arith.addi %do_adj2, %do_adj2_42 : i32 loc(#loc345) + %do_adj2_44 = arith.extsi %do_adj2_43 : i32 to i64 loc(#loc346) + %off_chz2 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc347) + %off_chz2_45 = arith.addi %off_chz2, %off_hq2_38 : i32 loc(#loc348) + %off_chz2_46 = arith.muli %off_chz2_45, %ks0 : i32 loc(#loc349) + %off_chz2_47 = arith.extsi %off_chz2_46 : i32 to i64 loc(#loc350) + %Q2 = tt.addptr %arg_Q, %q_adj2_41 : !tt.ptr, i64 loc(#loc351) + %DO2 = tt.addptr %arg_DO, %do_adj2_44 : !tt.ptr, i64 loc(#loc352) + %DQ2 = tt.addptr %arg_DQ, %q_adj2_41 : !tt.ptr, i64 loc(#loc353) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_47 : !tt.ptr, i64 loc(#loc354) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_47 : !tt.ptr, i64 loc(#loc355) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc356) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc357) + %offs_m2_48 = tt.splat %start_m2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc357) + %offs_m2_49 = arith.addi %offs_m2, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc357) + %offs_m2_50 = arith.addi %offs_m2_48, %offs_k_36 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc357) + %ptr = tt.expand_dims %offs_m2_49 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc590) + %ptr_51 = tt.expand_dims %offs_m2_50 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> loc(#loc590) + %ptr_52 = arith.muli %ptr, %cst_1 : tensor<128x1xi32, #blocked> loc(#loc591) + %ptr_53 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc592) + %ptr_54 = tt.addptr %ptr_53, %ptr_52 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc592) + %ptr_55 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc593) + %ptr_56 = tt.expand_dims %ptr_55 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc593) + %ptr_57 = tt.broadcast %ptr_54 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc594) + %ptr_58 = tt.broadcast %ptr_56 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc594) + %ptr_59 = tt.addptr %ptr_57, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc594) + %q = tt.splat %ks0 : i32 -> tensor<128x1xi32, #blocked> loc(#loc595) + %q_60 = tt.splat %ks0 : i32 -> tensor<128x1xi32, #mma1> loc(#loc595) + %q_61 = arith.cmpi slt, %ptr, %q : tensor<128x1xi32, #blocked> loc(#loc595) + %q_62 = tt.broadcast %q_61 : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc596) + %q_63 = tt.load %ptr_59, %q_62, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc596) + %q_64 = ttg.local_alloc %q_63 : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc596) + %ptr_65 = arith.muli %ptr, %cst_20 : tensor<128x1xi32, #blocked> loc(#loc597) + %ptr_66 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc598) + %ptr_67 = tt.addptr %ptr_66, %ptr_65 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc598) + %ptr_68 = tt.broadcast %ptr_67 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc599) + %ptr_69 = tt.addptr %ptr_68, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc599) + %do = tt.load %ptr_69, %q_62, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc600) + %do_70 = ttg.local_alloc %do : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc600) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc365) + %Di_71 = arith.cmpi slt, %offs_m2_50, %Di : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc365) + %Di_72 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc366) + %Di_73 = tt.addptr %Di_72, %offs_m2_50 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc366) + %Di_74 = tt.load %Di_73, %Di_71 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc367) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc368) + %lse_75 = tt.addptr %lse, %offs_m2_50 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc368) + %lse_76 = tt.load %lse_75, %Di_71 : tensor<128x!tt.ptr, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc369) + %lse_77 = arith.cmpf oeq, %lse_76, %cst_30 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc370) + %lse_78 = arith.select %lse_77, %cst_29, %lse_76 : tensor<128xi1, #ttg.slice<{dim = 1, parent = #mma1}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc371) + %lse_79 = tt.expand_dims %lse_78 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> loc(#loc372) + %kv_indices = tt.addptr %arg_KV_IDX, %start_m2_block : !tt.ptr, i32 loc(#loc373) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc374) + %kv_start_80 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc375) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc376) + %sparse_kv_num_blocks_81 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc377) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc378) + %offs_n2_82 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc378) + %offs_n2_83 = tt.splat %kv_start_80 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc379) + %offs_n2_84 = tt.splat %kv_start_80 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc379) + %offs_n2_85 = arith.addi %offs_n2_83, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc379) + %offs_n2_86 = arith.addi %offs_n2_84, %offs_n2_82 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc379) + %kT_ptrs = tt.expand_dims %offs_n2_86 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc601) + %kT_ptrs_87 = arith.muli %kT_ptrs, %cst_19 : tensor<1x64xi32, #blocked1> loc(#loc602) + %kT_ptrs_88 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc603) + %kT_ptrs_89 = tt.addptr %kT_ptrs_88, %kT_ptrs_87 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc603) + %kT_ptrs_90 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc604) + %kT_ptrs_91 = tt.expand_dims %kT_ptrs_90 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc604) + %kT_ptrs_92 = tt.broadcast %kT_ptrs_89 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc605) + %kT_ptrs_93 = tt.broadcast %kT_ptrs_91 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc605) + %kT_ptrs_94 = tt.addptr %kT_ptrs_92, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc605) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc606) + %vT_ptrs_95 = tt.addptr %vT_ptrs, %kT_ptrs_87 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc606) + %vT_ptrs_96 = tt.broadcast %vT_ptrs_95 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc607) + %vT_ptrs_97 = tt.addptr %vT_ptrs_96, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc607) + %hi = arith.muli %sparse_kv_num_blocks_81, %c2_i32 : i32 loc(#loc608) + %hi_98 = arith.addi %ks1, %c63_i32 : i32 loc(#loc762) + %hi_99 = arith.divsi %hi_98, %c64_i32 : i32 loc(#loc763) + %hi_100 = arith.maxsi %hi_99, %c1_i32 : i32 loc(#loc610) + %hi_101 = arith.minsi %hi, %hi_100 : i32 loc(#loc611) + %kT = tt.splat %ks1 : i32 -> tensor<1x64xi32, #mma1> loc(#loc909) + %kT_102 = tt.splat %ks1 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc909) + %m = arith.remsi %ptr_51, %q_60 : tensor<128x1xi32, #mma1> loc(#loc910) + %tmp3 = arith.cmpi slt, %m, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc766) + %tmp5 = tt.broadcast %m : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc767) + %tmp6 = tt.broadcast %tmp3 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc768) + %tmp7 = arith.cmpi sge, %m, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc769) + %tmp9 = tt.broadcast %tmp7 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc770) + %tmp14 = arith.remsi %m, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc771) + %tmp14_103 = arith.cmpi ne, %tmp14, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc772) + %tmp14_104 = arith.divsi %m, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc773) + %tmp14_105 = arith.subi %tmp14_104, %cst_26 : tensor<128x1xi32, #mma1> loc(#loc774) + %tmp14_106 = arith.select %tmp14_103, %tmp14_105, %tmp14_104 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc775) + %tmp14_107 = arith.select %tmp3, %tmp14_106, %tmp14_104 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc776) + %tmp17 = tt.broadcast %tmp14_107 : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc777) + %p = tt.broadcast %lse_79 : tensor<128x1xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc778) + %ds = tt.expand_dims %Di_74 {axis = 1 : i32} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> loc(#loc779) + %ds_108 = tt.broadcast %ds : tensor<128x1xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc780) + %kT_109 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc911) + %vT = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc912) + %vT_ptrs_110 = arith.cmpi sgt, %hi_101, %c0_i32 : i32 loc(#loc921) + %kT_111 = arith.cmpi slt, %kT_ptrs, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc909) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc911) + %kT_113 = ttg.memdesc_index %kT_109[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %vT_ptrs_114 = tt.splat %vT_ptrs_110 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc921) + %vT_ptrs_115 = arith.andi %vT_ptrs_114, %kT_112 : tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_116 = ttg.async_copy_global_to_local %kT_ptrs_94, %kT_113 mask %vT_ptrs_115 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %kT_117 = ttg.async_commit_group tokens %kT_116 loc(#loc911) + %vT_118 = ttg.memdesc_index %vT[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_119 = ttg.async_copy_global_to_local %vT_ptrs_97, %vT_118 mask %vT_ptrs_115 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_120 = ttg.async_commit_group tokens %vT_119 loc(#loc912) + %vT_ptrs_121 = arith.cmpi sgt, %hi_101, %c1_i32 : i32 loc(#loc921) + %kT_ptrs_122 = tt.addptr %kT_ptrs_94, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc614) + %vT_ptrs_123 = tt.addptr %vT_ptrs_97, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc615) + %offs_n2_124 = arith.addi %offs_n2_86, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc616) + %kT_125 = tt.expand_dims %offs_n2_124 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc914) + %kT_126 = arith.cmpi slt, %kT_125, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc909) + %kT_127 = tt.broadcast %kT_126 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc911) + %kT_128 = ttg.memdesc_index %kT_109[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %vT_ptrs_129 = tt.splat %vT_ptrs_121 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc921) + %vT_ptrs_130 = arith.andi %vT_ptrs_129, %kT_127 : tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_131 = ttg.async_copy_global_to_local %kT_ptrs_122, %kT_128 mask %vT_ptrs_130 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %kT_132 = ttg.async_commit_group tokens %kT_131 loc(#loc911) + %vT_133 = ttg.memdesc_index %vT[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_134 = ttg.async_copy_global_to_local %vT_ptrs_123, %vT_133 mask %vT_ptrs_130 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_135 = ttg.async_commit_group tokens %vT_134 loc(#loc912) + ttng.fence_async_shared {bCluster = false} loc(#loc783) + %vT_ptrs_136:12 = scf.for %vT_ptrs_192 = %c0_i32 to %hi_101 step %c1_i32 iter_args(%arg20 = %cst_6, %kT_ptrs_193 = %kT_ptrs_122, %offs_n2_194 = %offs_n2_124, %vT_ptrs_195 = %vT_ptrs_123, %offs_n2_196 = %offs_n2_85, %arg25 = %c1_i32, %arg26 = %c-1_i32, %kT_197 = %kT_117, %kT_198 = %kT_132, %vT_199 = %vT_120, %vT_200 = %vT_135, %arg31 = %c64_i32) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32) : i32 { + %vT_ptrs_201 = arith.subi %hi_101, %c2_i32 : i32 loc(#loc921) + %vT_ptrs_202 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_201 : i32 loc(#loc921) + %vT_ptrs_203 = arith.subi %hi_101, %c1_i32 : i32 loc(#loc921) + %vT_ptrs_204 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_203 : i32 loc(#loc921) + %vT_ptrs_205 = arith.addi %arg26, %c1_i32 : i32 loc(#loc921) + %vT_ptrs_206 = arith.cmpi sge, %vT_ptrs_205, %c3_i32 : i32 loc(#loc921) + %vT_ptrs_207 = arith.select %vT_ptrs_206, %c0_i32, %vT_ptrs_205 : i32 loc(#loc921) + %kT_208 = tt.expand_dims %offs_n2_196 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc914) + %kT_209 = arith.cmpi slt, %kT_208, %kT : tensor<1x64xi32, #mma1> loc(#loc909) + %kT_210 = tt.broadcast %kT_209 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc911) + %kT_211 = ttg.async_wait %kT_197, %vT_199 {num = 2 : i32} loc(#loc911) + %kT_212 = ttg.memdesc_index %kT_109[%vT_ptrs_207] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %dq_213 = ttg.memdesc_trans %kT_212 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc784) + %qk = ttng.warp_group_dot %q_64, %kT_212, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc783) + %qk_214:3 = ttng.warp_group_dot_wait %qk, %q_64, %kT_212 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc783) + %qk_215 = arith.mulf %qk_214#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc785) + %n = arith.remsi %kT_208, %kT : tensor<1x64xi32, #mma1> loc(#loc915) + %post_mod_scores = arith.select %kT_210, %qk_215, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc787) + %tmp5_216 = tt.broadcast %n : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc767) + %tmp5_217 = arith.cmpi sle, %tmp5_216, %tmp5 : tensor<128x64xi32, #mma1> loc(#loc767) + %tmp6_218 = arith.andi %tmp6, %tmp5_217 : tensor<128x64xi1, #mma1> loc(#loc768) + %tmp8 = arith.cmpi slt, %n, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc788) + %tmp9_219 = tt.broadcast %tmp8 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc770) + %tmp9_220 = arith.andi %tmp9, %tmp9_219 : tensor<128x64xi1, #mma1> loc(#loc770) + %tmp10 = arith.extui %tmp8 : tensor<1x64xi1, #mma1> to tensor<1x64xi32, #mma1> loc(#loc789) + %tmp10_221 = arith.cmpi eq, %tmp10, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc789) + %tmp11 = tt.broadcast %tmp10_221 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc790) + %tmp11_222 = arith.andi %tmp9, %tmp11 : tensor<128x64xi1, #mma1> loc(#loc790) + %tmp16 = arith.remsi %n, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc791) + %tmp16_223 = arith.cmpi ne, %tmp16, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc792) + %tmp16_224 = arith.divsi %n, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc793) + %tmp16_225 = arith.subi %tmp16_224, %cst_25 : tensor<1x64xi32, #mma1> loc(#loc794) + %tmp16_226 = arith.select %tmp16_223, %tmp16_225, %tmp16_224 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc795) + %tmp16_227 = arith.select %tmp8, %tmp16_226, %tmp16_224 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc796) + %tmp17_228 = tt.broadcast %tmp16_227 : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc777) + %tmp17_229 = arith.cmpi eq, %tmp17, %tmp17_228 : tensor<128x64xi32, #mma1> loc(#loc777) + %tmp18 = arith.andi %tmp11_222, %tmp17_229 : tensor<128x64xi1, #mma1> loc(#loc797) + %tmp19 = arith.ori %tmp9_220, %tmp18 : tensor<128x64xi1, #mma1> loc(#loc798) + %tmp20 = arith.ori %tmp6_218, %tmp19 : tensor<128x64xi1, #mma1> loc(#loc799) + %post_mod_scores_230 = arith.select %tmp20, %post_mod_scores, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc800) + %post_mod_scores_231 = arith.mulf %post_mod_scores_230, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc801) + %p_232 = arith.subf %post_mod_scores_231, %p : tensor<128x64xf32, #mma1> loc(#loc778) + %p_233 = math.exp2 %p_232 : tensor<128x64xf32, #mma1> loc(#loc802) + %vT_234 = ttg.memdesc_index %vT[%vT_ptrs_207] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %dp = ttng.warp_group_dot %do_70, %vT_234, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc803) + %dp_235:3 = ttng.warp_group_dot_wait %dp, %do_70, %vT_234 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc803) + %ds_236 = arith.subf %dp_235#0, %ds_108 : tensor<128x64xf32, #mma1> loc(#loc780) + %ds_237 = arith.mulf %p_233, %ds_236 : tensor<128x64xf32, #mma1> loc(#loc804) + %grad_scores = arith.select %kT_210, %ds_237, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc805) + %ds_238 = arith.select %tmp20, %grad_scores, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc806) + %ds_239 = arith.truncf %ds_238 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc807) + %ds_240 = ttg.convert_layout %ds_239 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc807) + %dq_241 = ttng.warp_group_dot %ds_240, %dq_213, %arg20 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc808) + %offs_n2_242 = tt.splat %arg31 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc616) + %offs_n2_243 = arith.addi %offs_n2_196, %offs_n2_242 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc616) + %vT_ptrs_244 = arith.addi %vT_ptrs_192, %c1_i32 : i32 loc(#loc921) + %cur_block_idx = arith.divsi %vT_ptrs_244, %c2_i32 : i32 loc(#loc809) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc810) + %cur_block_245 = tt.load %cur_block, %vT_ptrs_204 evictionPolicy = evict_last : !tt.ptr loc(#loc811) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc812) + %next_block_246 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_81 : i32 loc(#loc813) + %next_block_247 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc814) + %vT_ptrs_248 = arith.andi %vT_ptrs_204, %next_block_246 : i1 loc(#loc921) + %next_block_249 = tt.load %next_block_247, %vT_ptrs_248 evictionPolicy = evict_last : !tt.ptr loc(#loc815) + %needs_jump = arith.addi %vT_ptrs_192, %c2_i32 : i32 loc(#loc816) + %needs_jump_250 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc817) + %needs_jump_251 = arith.cmpi eq, %needs_jump_250, %c0_i32 : i32 loc(#loc818) + %jump_to_block = arith.subi %next_block_249, %cur_block_245 : i32 loc(#loc819) + %jump_to_block_252 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc820) + %jump_to_block_253 = arith.subi %jump_to_block_252, %c64_i32 : i32 loc(#loc821) + %offset = arith.extui %needs_jump_251 : i1 to i32 loc(#loc822) + %offset_254 = arith.muli %jump_to_block_253, %offset : i32 loc(#loc822) + %offset_255 = arith.subi %c1_i32, %offset : i32 loc(#loc823) + %offset_256 = arith.muli %offset_255, %c64_i32 : i32 loc(#loc824) + %offset_257 = arith.addi %offset_254, %offset_256 : i32 loc(#loc825) + %kT_ptrs_258 = arith.muli %offset_257, %c1024_i32 : i32 loc(#loc618) + %kT_ptrs_259 = tt.splat %kT_ptrs_258 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc614) + %kT_ptrs_260 = tt.addptr %kT_ptrs_193, %kT_ptrs_259 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc614) + %vT_ptrs_261 = tt.addptr %vT_ptrs_195, %kT_ptrs_259 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc615) + %offs_n2_262 = tt.splat %offset_257 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc616) + %offs_n2_263 = arith.addi %offs_n2_194, %offs_n2_262 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc616) + %vT_ptrs_264 = arith.addi %arg25, %c1_i32 : i32 loc(#loc921) + %vT_ptrs_265 = arith.cmpi sge, %vT_ptrs_264, %c3_i32 : i32 loc(#loc921) + %vT_ptrs_266 = arith.select %vT_ptrs_265, %c0_i32, %vT_ptrs_264 : i32 loc(#loc921) + %kT_267 = tt.expand_dims %offs_n2_263 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc914) + %kT_268 = arith.cmpi slt, %kT_267, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc909) + %kT_269 = tt.broadcast %kT_268 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc911) + %kT_270 = ttg.memdesc_index %kT_109[%vT_ptrs_266] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %vT_ptrs_271 = tt.splat %vT_ptrs_202 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc921) + %vT_ptrs_272 = arith.andi %vT_ptrs_271, %kT_269 : tensor<128x64xi1, #blocked1> loc(#loc921) + %kT_273 = ttg.async_copy_global_to_local %kT_ptrs_260, %kT_270 mask %vT_ptrs_272 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc911) + %kT_274 = ttg.async_commit_group tokens %kT_273 loc(#loc911) + %vT_275 = ttg.memdesc_index %vT[%vT_ptrs_266] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_276 = ttg.async_copy_global_to_local %vT_ptrs_261, %vT_275 mask %vT_ptrs_272 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc912) + %vT_277 = ttg.async_commit_group tokens %vT_276 loc(#loc912) + scf.yield %dq_241, %kT_ptrs_260, %offs_n2_263, %vT_ptrs_261, %offs_n2_243, %vT_ptrs_266, %vT_ptrs_207, %kT_198, %kT_274, %vT_200, %vT_277, %offset_257 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32 loc(#loc921) + } loc(#loc921) + %vT_ptrs_137 = ttng.warp_group_dot_wait %vT_ptrs_136#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> loc(#loc921) + %vT_ptrs_138 = ttg.async_wait {num = 0 : i32} loc(#loc921) + ttg.local_dealloc %vT : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc921) + ttg.local_dealloc %kT_109 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc921) + %kv_indices_139 = tt.addptr %arg_FULL_KV_IDX, %start_m2_block : !tt.ptr, i32 loc(#loc460) + %kv_start_140 = tt.load %kv_indices_139 : !tt.ptr loc(#loc461) + %kv_start_141 = arith.muli %kv_start_140, %c128_i32 : i32 loc(#loc462) + %sparse_kv_num_blocks_142 = tt.addptr %arg_FULL_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc463) + %sparse_kv_num_blocks_143 = tt.load %sparse_kv_num_blocks_142 : !tt.ptr loc(#loc464) + %offs_n2_144 = tt.splat %kv_start_141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc465) + %offs_n2_145 = tt.splat %kv_start_141 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc465) + %offs_n2_146 = arith.addi %offs_n2_144, %offs_n2 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc465) + %offs_n2_147 = arith.addi %offs_n2_145, %offs_n2_82 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc465) + %kT_ptrs_148 = tt.expand_dims %offs_n2_147 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc619) + %kT_ptrs_149 = arith.muli %kT_ptrs_148, %cst_19 : tensor<1x64xi32, #blocked1> loc(#loc620) + %kT_ptrs_150 = tt.addptr %kT_ptrs_88, %kT_ptrs_149 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc621) + %kT_ptrs_151 = tt.broadcast %kT_ptrs_150 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc622) + %kT_ptrs_152 = tt.addptr %kT_ptrs_151, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc622) + %vT_ptrs_153 = tt.addptr %vT_ptrs, %kT_ptrs_149 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc623) + %vT_ptrs_154 = tt.broadcast %vT_ptrs_153 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc624) + %vT_ptrs_155 = tt.addptr %vT_ptrs_154, %kT_ptrs_93 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc624) + %hi_156 = arith.muli %sparse_kv_num_blocks_143, %c2_i32 : i32 loc(#loc625) + %hi_157 = arith.minsi %hi_156, %hi_100 : i32 loc(#loc626) + %kT_158 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc916) + %vT_159 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc917) + %vT_ptrs_160 = arith.cmpi sgt, %hi_157, %c0_i32 : i32 loc(#loc922) + %kT_161 = arith.cmpi slt, %kT_ptrs_148, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc918) + %kT_162 = tt.broadcast %kT_161 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc916) + %kT_163 = ttg.memdesc_index %kT_158[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %vT_ptrs_164 = tt.splat %vT_ptrs_160 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc922) + %vT_ptrs_165 = arith.andi %vT_ptrs_164, %kT_162 : tensor<128x64xi1, #blocked1> loc(#loc922) + %kT_166 = ttg.async_copy_global_to_local %kT_ptrs_152, %kT_163 mask %vT_ptrs_165 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %kT_167 = ttg.async_commit_group tokens %kT_166 loc(#loc916) + %vT_168 = ttg.memdesc_index %vT_159[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_169 = ttg.async_copy_global_to_local %vT_ptrs_155, %vT_168 mask %vT_ptrs_165 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_170 = ttg.async_commit_group tokens %vT_169 loc(#loc917) + %vT_ptrs_171 = arith.cmpi sgt, %hi_157, %c1_i32 : i32 loc(#loc922) + %kT_ptrs_172 = tt.addptr %kT_ptrs_152, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc628) + %vT_ptrs_173 = tt.addptr %vT_ptrs_155, %cst_11 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc629) + %offs_n2_174 = arith.addi %offs_n2_147, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc630) + %kT_175 = tt.expand_dims %offs_n2_174 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc919) + %kT_176 = arith.cmpi slt, %kT_175, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc918) + %kT_177 = tt.broadcast %kT_176 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc916) + %kT_178 = ttg.memdesc_index %kT_158[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %vT_ptrs_179 = tt.splat %vT_ptrs_171 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc922) + %vT_ptrs_180 = arith.andi %vT_ptrs_179, %kT_177 : tensor<128x64xi1, #blocked1> loc(#loc922) + %kT_181 = ttg.async_copy_global_to_local %kT_ptrs_172, %kT_178 mask %vT_ptrs_180 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %kT_182 = ttg.async_commit_group tokens %kT_181 loc(#loc916) + %vT_183 = ttg.memdesc_index %vT_159[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_184 = ttg.async_copy_global_to_local %vT_ptrs_173, %vT_183 mask %vT_ptrs_180 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_185 = ttg.async_commit_group tokens %vT_184 loc(#loc917) + ttng.fence_async_shared {bCluster = false} loc(#loc828) + %vT_ptrs_186:12 = scf.for %vT_ptrs_192 = %c0_i32 to %hi_157 step %c1_i32 iter_args(%vT_ptrs_193 = %vT_ptrs_137, %kT_ptrs_194 = %kT_ptrs_172, %offs_n2_195 = %offs_n2_174, %vT_ptrs_196 = %vT_ptrs_173, %offs_n2_197 = %offs_n2_146, %arg25 = %c1_i32, %arg26 = %c-1_i32, %kT_198 = %kT_167, %kT_199 = %kT_182, %vT_200 = %vT_170, %vT_201 = %vT_185, %arg31 = %c64_i32) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32) : i32 { + %vT_ptrs_202 = arith.subi %hi_157, %c2_i32 : i32 loc(#loc922) + %vT_ptrs_203 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_202 : i32 loc(#loc922) + %vT_ptrs_204 = arith.subi %hi_157, %c1_i32 : i32 loc(#loc922) + %vT_ptrs_205 = arith.cmpi slt, %vT_ptrs_192, %vT_ptrs_204 : i32 loc(#loc922) + %vT_ptrs_206 = arith.addi %arg26, %c1_i32 : i32 loc(#loc922) + %vT_ptrs_207 = arith.cmpi sge, %vT_ptrs_206, %c3_i32 : i32 loc(#loc922) + %vT_ptrs_208 = arith.select %vT_ptrs_207, %c0_i32, %vT_ptrs_206 : i32 loc(#loc922) + %kT_209 = tt.expand_dims %offs_n2_197 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc919) + %kT_210 = arith.cmpi slt, %kT_209, %kT : tensor<1x64xi32, #mma1> loc(#loc918) + %kT_211 = tt.broadcast %kT_210 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc916) + %kT_212 = ttg.async_wait %kT_198, %vT_200 {num = 2 : i32} loc(#loc916) + %kT_213 = ttg.memdesc_index %kT_158[%vT_ptrs_208] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %dq_214 = ttg.memdesc_trans %kT_213 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc829) + %qk = ttng.warp_group_dot %q_64, %kT_213, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc828) + %qk_215:3 = ttng.warp_group_dot_wait %qk, %q_64, %kT_213 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc828) + %qk_216 = arith.mulf %qk_215#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc830) + %post_mod_scores = arith.select %kT_211, %qk_216, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc831) + %post_mod_scores_217 = arith.mulf %post_mod_scores, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc832) + %p_218 = arith.subf %post_mod_scores_217, %p : tensor<128x64xf32, #mma1> loc(#loc833) + %p_219 = math.exp2 %p_218 : tensor<128x64xf32, #mma1> loc(#loc834) + %vT_220 = ttg.memdesc_index %vT_159[%vT_ptrs_208] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %dp = ttng.warp_group_dot %do_70, %vT_220, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc835) + %dp_221:3 = ttng.warp_group_dot_wait %dp, %do_70, %vT_220 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc835) + %ds_222 = arith.subf %dp_221#0, %ds_108 : tensor<128x64xf32, #mma1> loc(#loc836) + %ds_223 = arith.mulf %p_219, %ds_222 : tensor<128x64xf32, #mma1> loc(#loc837) + %grad_scores = arith.select %kT_211, %ds_223, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc838) + %ds_224 = arith.truncf %grad_scores : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc839) + %ds_225 = ttg.convert_layout %ds_224 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc839) + %dq_226 = ttng.warp_group_dot %ds_225, %dq_214, %vT_ptrs_193 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc840) + %offs_n2_227 = tt.splat %arg31 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc630) + %offs_n2_228 = arith.addi %offs_n2_197, %offs_n2_227 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc630) + %vT_ptrs_229 = arith.addi %vT_ptrs_192, %c1_i32 : i32 loc(#loc922) + %cur_block_idx = arith.divsi %vT_ptrs_229, %c2_i32 : i32 loc(#loc841) + %cur_block = tt.addptr %kv_indices_139, %cur_block_idx : !tt.ptr, i32 loc(#loc842) + %cur_block_230 = tt.load %cur_block, %vT_ptrs_205 evictionPolicy = evict_last : !tt.ptr loc(#loc843) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc844) + %next_block_231 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_143 : i32 loc(#loc845) + %next_block_232 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc846) + %vT_ptrs_233 = arith.andi %vT_ptrs_205, %next_block_231 : i1 loc(#loc922) + %next_block_234 = tt.load %next_block_232, %vT_ptrs_233 evictionPolicy = evict_last : !tt.ptr loc(#loc847) + %needs_jump = arith.addi %vT_ptrs_192, %c2_i32 : i32 loc(#loc848) + %needs_jump_235 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc849) + %needs_jump_236 = arith.cmpi eq, %needs_jump_235, %c0_i32 : i32 loc(#loc850) + %jump_to_block = arith.subi %next_block_234, %cur_block_230 : i32 loc(#loc851) + %jump_to_block_237 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc852) + %jump_to_block_238 = arith.subi %jump_to_block_237, %c64_i32 : i32 loc(#loc853) + %offset = arith.extui %needs_jump_236 : i1 to i32 loc(#loc854) + %offset_239 = arith.muli %jump_to_block_238, %offset : i32 loc(#loc854) + %offset_240 = arith.subi %c1_i32, %offset : i32 loc(#loc855) + %offset_241 = arith.muli %offset_240, %c64_i32 : i32 loc(#loc856) + %offset_242 = arith.addi %offset_239, %offset_241 : i32 loc(#loc857) + %kT_ptrs_243 = arith.muli %offset_242, %c1024_i32 : i32 loc(#loc632) + %kT_ptrs_244 = tt.splat %kT_ptrs_243 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc628) + %kT_ptrs_245 = tt.addptr %kT_ptrs_194, %kT_ptrs_244 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc628) + %vT_ptrs_246 = tt.addptr %vT_ptrs_196, %kT_ptrs_244 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc629) + %offs_n2_247 = tt.splat %offset_242 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc630) + %offs_n2_248 = arith.addi %offs_n2_195, %offs_n2_247 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc630) + %vT_ptrs_249 = arith.addi %arg25, %c1_i32 : i32 loc(#loc922) + %vT_ptrs_250 = arith.cmpi sge, %vT_ptrs_249, %c3_i32 : i32 loc(#loc922) + %vT_ptrs_251 = arith.select %vT_ptrs_250, %c0_i32, %vT_ptrs_249 : i32 loc(#loc922) + %kT_252 = tt.expand_dims %offs_n2_248 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc919) + %kT_253 = arith.cmpi slt, %kT_252, %kT_102 : tensor<1x64xi32, #blocked1> loc(#loc918) + %kT_254 = tt.broadcast %kT_253 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc916) + %kT_255 = ttg.memdesc_index %kT_158[%vT_ptrs_251] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %vT_ptrs_256 = tt.splat %vT_ptrs_203 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc922) + %vT_ptrs_257 = arith.andi %vT_ptrs_256, %kT_254 : tensor<128x64xi1, #blocked1> loc(#loc922) + %kT_258 = ttg.async_copy_global_to_local %kT_ptrs_245, %kT_255 mask %vT_ptrs_257 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc916) + %kT_259 = ttg.async_commit_group tokens %kT_258 loc(#loc916) + %vT_260 = ttg.memdesc_index %vT_159[%vT_ptrs_251] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_261 = ttg.async_copy_global_to_local %vT_ptrs_246, %vT_260 mask %vT_ptrs_257 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc917) + %vT_262 = ttg.async_commit_group tokens %vT_261 loc(#loc917) + scf.yield %dq_226, %kT_ptrs_245, %offs_n2_248, %vT_ptrs_246, %offs_n2_228, %vT_ptrs_251, %vT_ptrs_208, %kT_199, %kT_259, %vT_201, %vT_262, %offset_242 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32 loc(#loc922) + } loc(#loc922) + %vT_ptrs_187 = ttng.warp_group_dot_wait %vT_ptrs_186#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma> loc(#loc922) + %vT_ptrs_188 = ttg.async_wait {num = 0 : i32} loc(#loc922) + ttg.local_dealloc %vT_159 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc922) + ttg.local_dealloc %kT_158 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc922) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc467) + %dq_ptrs_189 = tt.addptr %dq_ptrs, %ptr_52 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc467) + %dq_ptrs_190 = tt.broadcast %dq_ptrs_189 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc468) + %dq_ptrs_191 = tt.addptr %dq_ptrs_190, %ptr_58 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc468) + %dq = arith.mulf %vT_ptrs_187, %cst_5 : tensor<128x128xf32, #mma> loc(#loc469) + %11 = arith.cmpi slt, %ptr_56, %cst_0 : tensor<1x128xi32, #blocked> loc(#loc171) + %12 = tt.broadcast %11 : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc172) + %13 = arith.andi %q_62, %12 : tensor<128x128xi1, #blocked> loc(#loc172) + %14 = arith.truncf %dq : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc173) + %15 = ttg.convert_layout %14 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc173) + tt.store %dq_ptrs_191, %15, %13 : tensor<128x128x!tt.ptr, #blocked> loc(#loc173) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc470) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc471) + %offs_n1_37 = tt.splat %start_n1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc471) + %offs_n1_38 = arith.addi %offs_n1, %offs_k : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc471) + %offs_n1_39 = arith.addi %offs_n1_37, %offs_k_36 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> loc(#loc471) + %ptr = tt.expand_dims %offs_n1_38 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc633) + %ptr_40 = tt.expand_dims %offs_n1_39 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> loc(#loc633) + %ptr_41 = arith.muli %ptr, %cst : tensor<128x1xi32, #blocked> loc(#loc634) + %ptr_42 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc635) + %ptr_43 = tt.addptr %ptr_42, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc635) + %ptr_44 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc636) + %ptr_45 = tt.expand_dims %ptr_44 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc636) + %ptr_46 = tt.broadcast %ptr_43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc637) + %ptr_47 = tt.broadcast %ptr_45 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc637) + %ptr_48 = tt.addptr %ptr_46, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc637) + %k = tt.splat %ks1 : i32 -> tensor<128x1xi32, #blocked> loc(#loc638) + %k_49 = tt.splat %ks1 : i32 -> tensor<128x1xi32, #mma1> loc(#loc638) + %k_50 = arith.cmpi slt, %ptr, %k : tensor<128x1xi32, #blocked> loc(#loc638) + %k_51 = tt.broadcast %k_50 : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc639) + %k_52 = tt.load %ptr_48, %k_51, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc639) + %k_53 = ttg.local_alloc %k_52 : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc639) + %ptr_54 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc640) + %ptr_55 = tt.addptr %ptr_54, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc640) + %ptr_56 = tt.broadcast %ptr_55 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc641) + %ptr_57 = tt.addptr %ptr_56, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc641) + %v = tt.load %ptr_57, %k_51, %cst_2 : tensor<128x128x!tt.ptr, #blocked> loc(#loc642) + %v_58 = ttg.local_alloc %v : (tensor<128x128xbf16, #blocked>) -> !ttg.memdesc<128x128xbf16, #shared, #smem> loc(#loc642) + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc474) + %q_adj1 = arith.muli %0, %off_zq : i32 loc(#loc475) + %do_adj1 = arith.muli %7, %off_zq : i32 loc(#loc476) + %off_chz1 = arith.muli %off_zq, %c32_i32 : i32 loc(#loc477) + %q_indices = tt.addptr %arg_Q_IDX, %pid : !tt.ptr, i32 loc(#loc478) + %q_start = tt.load %q_indices, %true : !tt.ptr loc(#loc479) + %q_start_59 = arith.muli %q_start, %c128_i32 : i32 loc(#loc480) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc481) + %sparse_q_num_blocks_60 = tt.load %sparse_q_num_blocks, %true : !tt.ptr loc(#loc482) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc483) + %offs_m1_61 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc483) + %offs_m1_62 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc483) + %offs_m1_63 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc484) + %offs_m1_64 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc484) + %offs_m1_65 = tt.splat %q_start_59 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc484) + %offs_m1_66 = arith.addi %offs_m1_63, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc484) + %offs_m1_67 = arith.addi %offs_m1_64, %offs_m1_61 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc484) + %offs_m1_68 = arith.addi %offs_m1_65, %offs_m1_62 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc484) + %qT_ptrs = tt.expand_dims %offs_m1_67 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc643) + %qT_ptrs_69 = arith.muli %qT_ptrs, %cst_18 : tensor<1x64xi32, #blocked1> loc(#loc644) + %qT_ptrs_70 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc645) + %qT_ptrs_71 = tt.expand_dims %qT_ptrs_70 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc645) + %qT_ptrs_72 = tt.broadcast %qT_ptrs_71 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1> loc(#loc646) + %do_ptrs = tt.expand_dims %offs_m1_68 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc647) + %do_ptrs_73 = arith.muli %do_ptrs, %cst_17 : tensor<64x1xi32, #blocked> loc(#loc648) + %do_ptrs_74 = tt.broadcast %ptr_45 : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> loc(#loc649) + %hi = arith.muli %sparse_q_num_blocks_60, %c2_i32 : i32 loc(#loc650) + %hi_75 = arith.addi %ks0, %c63_i32 : i32 loc(#loc858) + %hi_76 = arith.divsi %hi_75, %c64_i32 : i32 loc(#loc859) + %hi_77 = arith.maxsi %hi_76, %c1_i32 : i32 loc(#loc652) + %hi_78 = arith.minsi %hi, %hi_77 : i32 loc(#loc653) + %qT = tt.splat %ks0 : i32 -> tensor<1x64xi32, #mma1> loc(#loc860) + %qT_79 = tt.splat %ks0 : i32 -> tensor<1x64xi32, #blocked1> loc(#loc860) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc655) + %n = arith.remsi %ptr_40, %k_49 : tensor<128x1xi32, #mma1> loc(#loc861) + %tmp27 = tt.broadcast %n : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc657) + %tmp30 = arith.cmpi slt, %n, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc658) + %tmp31 = tt.broadcast %tmp30 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc659) + %tmp32 = arith.extui %tmp30 : tensor<128x1xi1, #mma1> to tensor<128x1xi32, #mma1> loc(#loc660) + %tmp32_80 = arith.cmpi eq, %tmp32, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc660) + %tmp33 = tt.broadcast %tmp32_80 : tensor<128x1xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc661) + %tmp38 = arith.remsi %n, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc662) + %tmp38_81 = arith.cmpi ne, %tmp38, %cst_27 : tensor<128x1xi32, #mma1> loc(#loc663) + %tmp38_82 = arith.divsi %n, %cst_21 : tensor<128x1xi32, #mma1> loc(#loc664) + %tmp38_83 = arith.subi %tmp38_82, %cst_26 : tensor<128x1xi32, #mma1> loc(#loc665) + %tmp38_84 = arith.select %tmp38_81, %tmp38_83, %tmp38_82 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc666) + %tmp38_85 = arith.select %tmp30, %tmp38_84, %tmp38_82 : tensor<128x1xi1, #mma1>, tensor<128x1xi32, #mma1> loc(#loc667) + %tmp39 = tt.broadcast %tmp38_85 : tensor<128x1xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc668) + %do = tt.splat %ks0 : i32 -> tensor<64x1xi32, #blocked> loc(#loc862) + %q_indices_86 = tt.addptr %arg_FULL_Q_IDX, %pid : !tt.ptr, i32 loc(#loc513) + %q_start_87 = tt.load %q_indices_86, %true : !tt.ptr loc(#loc514) + %q_start_88 = arith.muli %q_start_87, %c128_i32 : i32 loc(#loc515) + %sparse_q_num_blocks_89 = tt.addptr %arg_FULL_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc516) + %sparse_q_num_blocks_90 = tt.load %sparse_q_num_blocks_89, %true : !tt.ptr loc(#loc517) + %offs_m1_91 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc518) + %offs_m1_92 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc518) + %offs_m1_93 = tt.splat %q_start_88 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc518) + %offs_m1_94 = arith.addi %offs_m1_91, %offs_m1 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc518) + %offs_m1_95 = arith.addi %offs_m1_92, %offs_m1_61 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc518) + %offs_m1_96 = arith.addi %offs_m1_93, %offs_m1_62 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc518) + %qT_ptrs_97 = tt.expand_dims %offs_m1_95 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc670) + %qT_ptrs_98 = arith.muli %qT_ptrs_97, %cst_18 : tensor<1x64xi32, #blocked1> loc(#loc671) + %do_ptrs_99 = tt.expand_dims %offs_m1_96 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc672) + %do_ptrs_100 = arith.muli %do_ptrs_99, %cst_17 : tensor<64x1xi32, #blocked> loc(#loc673) + %hi_101 = arith.muli %sparse_q_num_blocks_90, %c2_i32 : i32 loc(#loc674) + %hi_102 = arith.minsi %hi_101, %hi_77 : i32 loc(#loc675) + ttng.fence_async_shared {bCluster = false} loc(#loc676) + %dk:2 = scf.for %dk_107 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg20 = %cst_6, %arg21 = %cst_6) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>) : i32 { + %off_hq1_108 = arith.addi %off_hq1, %dk_107 : i32 loc(#loc521) + %q_adj1_109 = arith.muli %off_hq1_108, %c128_i32 : i32 loc(#loc522) + %q_adj1_110 = arith.addi %q_adj1_109, %q_adj1 : i32 loc(#loc523) + %q_adj1_111 = arith.extsi %q_adj1_110 : i32 to i64 loc(#loc524) + %do_adj1_112 = arith.muli %8, %off_hq1_108 : i32 loc(#loc525) + %do_adj1_113 = arith.addi %do_adj1_112, %do_adj1 : i32 loc(#loc526) + %do_adj1_114 = arith.extsi %do_adj1_113 : i32 to i64 loc(#loc527) + %off_chz1_115 = arith.addi %off_chz1, %off_hq1_108 : i32 loc(#loc528) + %off_chz1_116 = arith.muli %off_chz1_115, %ks0 : i32 loc(#loc529) + %off_chz1_117 = arith.extsi %off_chz1_116 : i32 to i64 loc(#loc530) + %Q1 = tt.addptr %arg_Q, %q_adj1_111 : !tt.ptr, i64 loc(#loc531) + %DO1 = tt.addptr %arg_DO, %do_adj1_114 : !tt.ptr, i64 loc(#loc532) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_117 : !tt.ptr, i64 loc(#loc533) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_117 : !tt.ptr, i64 loc(#loc534) + %qT_ptrs_118 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr, #blocked1> loc(#loc678) + %qT_ptrs_119 = tt.addptr %qT_ptrs_118, %qT_ptrs_69 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc678) + %qT_ptrs_120 = tt.broadcast %qT_ptrs_119 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc646) + %qT_ptrs_121 = tt.addptr %qT_ptrs_120, %qT_ptrs_72 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc646) + %do_ptrs_122 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> loc(#loc679) + %do_ptrs_123 = tt.addptr %do_ptrs_122, %do_ptrs_73 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc679) + %do_ptrs_124 = tt.broadcast %do_ptrs_123 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc649) + %do_ptrs_125 = tt.addptr %do_ptrs_124, %do_ptrs_74 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc649) + %lse_126 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc680) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc681) + %qT_127 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc863) + %lse_128 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc682) + %do_129 = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc864) + %Di_130 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc683) + %do_ptrs_131 = arith.cmpi sgt, %hi_78, %c0_i32 : i32 loc(#loc924) + %qT_132 = arith.cmpi slt, %qT_ptrs, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc860) + %qT_133 = tt.broadcast %qT_132 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc863) + %qT_134 = ttg.memdesc_index %qT_127[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %do_ptrs_135 = tt.splat %do_ptrs_131 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc924) + %do_ptrs_136 = arith.andi %do_ptrs_135, %qT_133 : tensor<128x64xi1, #blocked1> loc(#loc924) + %qT_137 = ttg.async_copy_global_to_local %qT_ptrs_121, %qT_134 mask %do_ptrs_136 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %qT_138 = ttg.async_commit_group tokens %qT_137 loc(#loc863) + %lse_139 = arith.cmpi slt, %offs_m1_66, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc655) + %lse_140 = tt.addptr %lse_126, %offs_m1_66 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc680) + %lse_141 = ttg.memdesc_index %lse_128[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %do_ptrs_142 = tt.splat %do_ptrs_131 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %do_ptrs_143 = arith.andi %do_ptrs_142, %lse_139 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %lse_144 = ttg.async_copy_global_to_local %lse_140, %lse_141 mask %do_ptrs_143 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %lse_145 = ttg.async_commit_group tokens %lse_144 loc(#loc682) + %do_146 = arith.cmpi slt, %do_ptrs, %do : tensor<64x1xi32, #blocked> loc(#loc862) + %do_147 = tt.broadcast %do_146 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc864) + %do_148 = ttg.memdesc_index %do_129[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_ptrs_149 = tt.splat %do_ptrs_131 : i1 -> tensor<64x128xi1, #blocked> loc(#loc924) + %do_ptrs_150 = arith.andi %do_ptrs_149, %do_147 : tensor<64x128xi1, #blocked> loc(#loc924) + %do_151 = ttg.async_copy_global_to_local %do_ptrs_125, %do_148 mask %do_ptrs_150 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_152 = ttg.async_commit_group tokens %do_151 loc(#loc864) + %Di_153 = tt.addptr %Di, %offs_m1_66 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc681) + %Di_154 = ttg.memdesc_index %Di_130[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_155 = ttg.async_copy_global_to_local %Di_153, %Di_154 mask %do_ptrs_143 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_156 = ttg.async_commit_group tokens %Di_155 loc(#loc683) + %do_ptrs_157 = arith.cmpi sgt, %hi_78, %c1_i32 : i32 loc(#loc924) + %qT_ptrs_158 = tt.addptr %qT_ptrs_121, %cst_13 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc685) + %do_ptrs_159 = tt.addptr %do_ptrs_125, %cst_14 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc686) + %offs_m1_160 = arith.addi %offs_m1_66, %cst_15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc687) + %offs_m1_161 = arith.addi %offs_m1_67, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc687) + %offs_m1_162 = arith.addi %offs_m1_68, %cst_16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc687) + %qT_163 = tt.expand_dims %offs_m1_161 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc866) + %qT_164 = arith.cmpi slt, %qT_163, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc860) + %qT_165 = tt.broadcast %qT_164 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc863) + %qT_166 = ttg.memdesc_index %qT_127[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %do_ptrs_167 = tt.splat %do_ptrs_157 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc924) + %do_ptrs_168 = arith.andi %do_ptrs_167, %qT_165 : tensor<128x64xi1, #blocked1> loc(#loc924) + %qT_169 = ttg.async_copy_global_to_local %qT_ptrs_158, %qT_166 mask %do_ptrs_168 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %qT_170 = ttg.async_commit_group tokens %qT_169 loc(#loc863) + %lse_171 = arith.cmpi slt, %offs_m1_160, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc655) + %lse_172 = tt.addptr %lse_126, %offs_m1_160 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc680) + %lse_173 = ttg.memdesc_index %lse_128[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %do_ptrs_174 = tt.splat %do_ptrs_157 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %do_ptrs_175 = arith.andi %do_ptrs_174, %lse_171 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %lse_176 = ttg.async_copy_global_to_local %lse_172, %lse_173 mask %do_ptrs_175 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %lse_177 = ttg.async_commit_group tokens %lse_176 loc(#loc682) + %do_178 = tt.expand_dims %offs_m1_162 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc867) + %do_179 = arith.cmpi slt, %do_178, %do : tensor<64x1xi32, #blocked> loc(#loc862) + %do_180 = tt.broadcast %do_179 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc864) + %do_181 = ttg.memdesc_index %do_129[%c1_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_ptrs_182 = tt.splat %do_ptrs_157 : i1 -> tensor<64x128xi1, #blocked> loc(#loc924) + %do_ptrs_183 = arith.andi %do_ptrs_182, %do_180 : tensor<64x128xi1, #blocked> loc(#loc924) + %do_184 = ttg.async_copy_global_to_local %do_ptrs_159, %do_181 mask %do_ptrs_183 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_185 = ttg.async_commit_group tokens %do_184 loc(#loc864) + %Di_186 = tt.addptr %Di, %offs_m1_160 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc681) + %Di_187 = ttg.memdesc_index %Di_130[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_188 = ttg.async_copy_global_to_local %Di_186, %Di_187 mask %do_ptrs_175 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_189 = ttg.async_commit_group tokens %Di_188 loc(#loc683) + %do_ptrs_190:22 = scf.for %do_ptrs_265 = %c0_i32 to %hi_78 step %c1_i32 iter_args(%arg23 = %arg21, %arg24 = %arg20, %qT_ptrs_266 = %qT_ptrs_158, %offs_m1_267 = %offs_m1_161, %do_ptrs_268 = %do_ptrs_159, %offs_m1_269 = %offs_m1_162, %offs_m1_270 = %offs_m1_160, %arg30 = %c1_i32, %arg31 = %c-1_i32, %arg32 = %c1_i32, %arg33 = %c-1_i32, %offs_m1_271 = %offs_m1_66, %qT_272 = %qT_138, %qT_273 = %qT_170, %lse_274 = %lse_145, %lse_275 = %lse_177, %do_276 = %do_152, %do_277 = %do_185, %Di_278 = %Di_156, %Di_279 = %Di_189, %arg43 = %c64_i32, %offs_m1_280 = %offs_m1_66) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>) : i32 { + %do_ptrs_281 = arith.subi %hi_78, %c2_i32 : i32 loc(#loc924) + %do_ptrs_282 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_281 : i32 loc(#loc924) + %do_ptrs_283 = arith.subi %hi_78, %c1_i32 : i32 loc(#loc924) + %do_ptrs_284 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_283 : i32 loc(#loc924) + %do_ptrs_285 = arith.addi %arg33, %c1_i32 : i32 loc(#loc924) + %do_ptrs_286 = arith.cmpi sge, %do_ptrs_285, %c2_i32 : i32 loc(#loc924) + %do_ptrs_287 = arith.select %do_ptrs_286, %c0_i32, %do_ptrs_285 : i32 loc(#loc924) + %do_ptrs_288 = arith.addi %arg31, %c1_i32 : i32 loc(#loc924) + %do_ptrs_289 = arith.cmpi sge, %do_ptrs_288, %c3_i32 : i32 loc(#loc924) + %do_ptrs_290 = arith.select %do_ptrs_289, %c0_i32, %do_ptrs_288 : i32 loc(#loc924) + %qT_291 = tt.expand_dims %offs_m1_271 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc866) + %qT_292 = tt.expand_dims %offs_m1_280 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc866) + %qT_293 = arith.cmpi slt, %qT_291, %qT : tensor<1x64xi32, #mma1> loc(#loc860) + %qT_294 = tt.broadcast %qT_293 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc863) + %qT_295 = ttg.async_wait %qT_272, %lse_274, %do_276, %Di_278 {num = 4 : i32} loc(#loc863) + %qT_296 = ttg.memdesc_index %qT_127[%do_ptrs_290] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %dk_297 = ttg.memdesc_trans %qT_296 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc688) + %lse_298 = ttg.memdesc_index %lse_128[%do_ptrs_287] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %lse_299 = ttg.local_load %lse_298 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc682) + %lse_300 = arith.cmpf oeq, %lse_299, %cst_23 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc689) + %lse_301 = arith.select %lse_300, %cst_24, %lse_299 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc690) + %qkT = ttng.warp_group_dot %k_53, %qT_296, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc676) + %qkT_302:3 = ttng.warp_group_dot_wait %qkT, %k_53, %qT_296 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc676) + %qkT_303 = arith.mulf %qkT_302#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc691) + %m = arith.remsi %qT_292, %qT : tensor<1x64xi32, #mma1> loc(#loc868) + %post_mod_scores = arith.select %qT_294, %qkT_303, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc693) + %tmp25 = arith.cmpi slt, %m, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc694) + %tmp27_304 = tt.broadcast %m : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc657) + %tmp27_305 = arith.cmpi sle, %tmp27, %tmp27_304 : tensor<128x64xi32, #mma1> loc(#loc657) + %tmp28 = tt.broadcast %tmp25 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc695) + %tmp28_306 = arith.andi %tmp28, %tmp27_305 : tensor<128x64xi1, #mma1> loc(#loc695) + %tmp29 = arith.cmpi sge, %m, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc696) + %tmp31_307 = tt.broadcast %tmp29 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc659) + %tmp31_308 = arith.andi %tmp31_307, %tmp31 : tensor<128x64xi1, #mma1> loc(#loc659) + %tmp33_309 = arith.andi %tmp31_307, %tmp33 : tensor<128x64xi1, #mma1> loc(#loc661) + %tmp36 = arith.remsi %m, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc697) + %tmp36_310 = arith.cmpi ne, %tmp36, %cst_28 : tensor<1x64xi32, #mma1> loc(#loc698) + %tmp36_311 = arith.divsi %m, %cst_22 : tensor<1x64xi32, #mma1> loc(#loc699) + %tmp36_312 = arith.subi %tmp36_311, %cst_25 : tensor<1x64xi32, #mma1> loc(#loc700) + %tmp36_313 = arith.select %tmp36_310, %tmp36_312, %tmp36_311 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc701) + %tmp36_314 = arith.select %tmp25, %tmp36_313, %tmp36_311 : tensor<1x64xi1, #mma1>, tensor<1x64xi32, #mma1> loc(#loc702) + %tmp39_315 = tt.broadcast %tmp36_314 : tensor<1x64xi32, #mma1> -> tensor<128x64xi32, #mma1> loc(#loc668) + %tmp39_316 = arith.cmpi eq, %tmp39_315, %tmp39 : tensor<128x64xi32, #mma1> loc(#loc668) + %tmp40 = arith.andi %tmp33_309, %tmp39_316 : tensor<128x64xi1, #mma1> loc(#loc703) + %tmp41 = arith.ori %tmp31_308, %tmp40 : tensor<128x64xi1, #mma1> loc(#loc704) + %tmp42 = arith.ori %tmp28_306, %tmp41 : tensor<128x64xi1, #mma1> loc(#loc705) + %post_mod_scores_317 = arith.select %tmp42, %post_mod_scores, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc706) + %post_mod_scores_318 = arith.mulf %post_mod_scores_317, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc707) + %pT = tt.expand_dims %lse_301 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc708) + %pT_319 = tt.broadcast %pT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc709) + %pT_320 = arith.subf %post_mod_scores_318, %pT_319 : tensor<128x64xf32, #mma1> loc(#loc709) + %pT_321 = math.exp2 %pT_320 : tensor<128x64xf32, #mma1> loc(#loc710) + %do_322 = ttg.memdesc_index %do_129[%do_ptrs_290] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %dpT = ttg.memdesc_trans %do_322 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc711) + %dv = arith.truncf %pT_321 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc712) + %dv_323 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc712) + %dv_324 = ttng.warp_group_dot %dv_323, %do_322, %arg24 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc713) + %Di_325 = ttg.memdesc_index %Di_130[%do_ptrs_287] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_326 = ttg.local_load %Di_325 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc683) + %dpT_327 = ttng.warp_group_dot %v_58, %dpT, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc714) + %dpT_328:3 = ttng.warp_group_dot_wait %dpT_327, %v_58, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc714) + %dsT = tt.expand_dims %Di_326 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc715) + %dsT_329 = tt.broadcast %dsT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc716) + %dsT_330 = arith.subf %dpT_328#0, %dsT_329 : tensor<128x64xf32, #mma1> loc(#loc716) + %dsT_331 = arith.mulf %pT_321, %dsT_330 : tensor<128x64xf32, #mma1> loc(#loc717) + %grad_scores = arith.select %qT_294, %dsT_331, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc718) + %dsT_332 = arith.select %tmp42, %grad_scores, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc719) + %dk_333 = arith.truncf %dsT_332 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc720) + %dk_334 = ttg.convert_layout %dk_333 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc720) + %dk_335 = ttng.warp_group_dot %dk_334, %dk_297, %arg23 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc721) + %offs_m1_336 = tt.splat %arg43 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc687) + %offs_m1_337 = arith.addi %offs_m1_280, %offs_m1_336 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc687) + %do_ptrs_338 = arith.addi %do_ptrs_265, %c1_i32 : i32 loc(#loc924) + %cur_block_idx = arith.divsi %do_ptrs_338, %c2_i32 : i32 loc(#loc869) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc870) + %cur_block_339 = tt.load %cur_block, %do_ptrs_284 evictionPolicy = evict_last : !tt.ptr loc(#loc871) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc872) + %next_block_340 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_60 : i32 loc(#loc873) + %next_block_341 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc874) + %do_ptrs_342 = arith.andi %do_ptrs_284, %next_block_340 : i1 loc(#loc924) + %next_block_343 = tt.load %next_block_341, %do_ptrs_342 evictionPolicy = evict_last : !tt.ptr loc(#loc875) + %needs_jump = arith.addi %do_ptrs_265, %c2_i32 : i32 loc(#loc876) + %needs_jump_344 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc877) + %needs_jump_345 = arith.cmpi eq, %needs_jump_344, %c0_i32 : i32 loc(#loc878) + %jump_to_block = arith.subi %next_block_343, %cur_block_339 : i32 loc(#loc879) + %jump_to_block_346 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc880) + %jump_to_block_347 = arith.subi %jump_to_block_346, %c64_i32 : i32 loc(#loc881) + %offset = arith.extui %needs_jump_345 : i1 to i32 loc(#loc882) + %offset_348 = arith.muli %jump_to_block_347, %offset : i32 loc(#loc882) + %offset_349 = arith.subi %c1_i32, %offset : i32 loc(#loc883) + %offset_350 = arith.muli %offset_349, %c64_i32 : i32 loc(#loc884) + %offset_351 = arith.addi %offset_348, %offset_350 : i32 loc(#loc885) + %qT_ptrs_352 = arith.muli %offset_351, %c4096_i32 : i32 loc(#loc723) + %qT_ptrs_353 = tt.splat %qT_ptrs_352 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc685) + %qT_ptrs_354 = tt.addptr %qT_ptrs_266, %qT_ptrs_353 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc685) + %do_ptrs_355 = arith.muli %offset_351, %c128_i32 : i32 loc(#loc724) + %do_ptrs_356 = tt.splat %do_ptrs_355 : i32 -> tensor<64x128xi32, #blocked> loc(#loc686) + %do_ptrs_357 = tt.addptr %do_ptrs_268, %do_ptrs_356 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc686) + %offs_m1_358 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc687) + %offs_m1_359 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc687) + %offs_m1_360 = tt.splat %offset_351 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc687) + %offs_m1_361 = arith.addi %offs_m1_270, %offs_m1_358 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc687) + %offs_m1_362 = arith.addi %offs_m1_267, %offs_m1_359 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc687) + %offs_m1_363 = arith.addi %offs_m1_269, %offs_m1_360 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc687) + %do_ptrs_364 = arith.addi %arg32, %c1_i32 : i32 loc(#loc924) + %do_ptrs_365 = arith.cmpi sge, %do_ptrs_364, %c2_i32 : i32 loc(#loc924) + %do_ptrs_366 = arith.select %do_ptrs_365, %c0_i32, %do_ptrs_364 : i32 loc(#loc924) + %do_ptrs_367 = arith.addi %arg30, %c1_i32 : i32 loc(#loc924) + %do_ptrs_368 = arith.cmpi sge, %do_ptrs_367, %c3_i32 : i32 loc(#loc924) + %do_ptrs_369 = arith.select %do_ptrs_368, %c0_i32, %do_ptrs_367 : i32 loc(#loc924) + %qT_370 = tt.expand_dims %offs_m1_362 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc866) + %qT_371 = arith.cmpi slt, %qT_370, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc860) + %qT_372 = tt.broadcast %qT_371 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc863) + %qT_373 = ttg.memdesc_index %qT_127[%do_ptrs_369] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %do_ptrs_374 = tt.splat %do_ptrs_282 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc924) + %do_ptrs_375 = arith.andi %do_ptrs_374, %qT_372 : tensor<128x64xi1, #blocked1> loc(#loc924) + %qT_376 = ttg.async_copy_global_to_local %qT_ptrs_354, %qT_373 mask %do_ptrs_375 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc863) + %qT_377 = ttg.async_commit_group tokens %qT_376 loc(#loc863) + %lse_378 = arith.cmpi slt, %offs_m1_361, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc655) + %lse_379 = tt.addptr %lse_126, %offs_m1_361 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc680) + %lse_380 = ttg.memdesc_index %lse_128[%do_ptrs_366] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %do_ptrs_381 = tt.splat %do_ptrs_282 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %do_ptrs_382 = arith.andi %do_ptrs_381, %lse_378 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + %lse_383 = ttg.async_copy_global_to_local %lse_379, %lse_380 mask %do_ptrs_382 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc682) + %lse_384 = ttg.async_commit_group tokens %lse_383 loc(#loc682) + %do_385 = tt.expand_dims %offs_m1_363 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc867) + %do_386 = arith.cmpi slt, %do_385, %do : tensor<64x1xi32, #blocked> loc(#loc862) + %do_387 = tt.broadcast %do_386 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc864) + %do_388 = ttg.memdesc_index %do_129[%do_ptrs_369] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_ptrs_389 = tt.splat %do_ptrs_282 : i1 -> tensor<64x128xi1, #blocked> loc(#loc924) + %do_ptrs_390 = arith.andi %do_ptrs_389, %do_387 : tensor<64x128xi1, #blocked> loc(#loc924) + %do_391 = ttg.async_copy_global_to_local %do_ptrs_357, %do_388 mask %do_ptrs_390 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc864) + %do_392 = ttg.async_commit_group tokens %do_391 loc(#loc864) + %Di_393 = tt.addptr %Di, %offs_m1_361 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc681) + %Di_394 = ttg.memdesc_index %Di_130[%do_ptrs_366] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_395 = ttg.async_copy_global_to_local %Di_393, %Di_394 mask %do_ptrs_382 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc683) + %Di_396 = ttg.async_commit_group tokens %Di_395 loc(#loc683) + scf.yield %dk_335, %dv_324, %qT_ptrs_354, %offs_m1_362, %do_ptrs_357, %offs_m1_363, %offs_m1_361, %do_ptrs_369, %do_ptrs_290, %do_ptrs_366, %do_ptrs_287, %offs_m1_270, %qT_273, %qT_377, %lse_275, %lse_384, %do_277, %do_392, %Di_279, %Di_396, %offset_351, %offs_m1_337 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc924) + } loc(#loc924) + %do_ptrs_191:2 = ttng.warp_group_dot_wait %do_ptrs_190#1, %do_ptrs_190#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc924) + %do_ptrs_192 = ttg.async_wait {num = 0 : i32} loc(#loc924) + ttg.local_dealloc %Di_130 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc924) + ttg.local_dealloc %do_129 : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc924) + ttg.local_dealloc %lse_128 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc924) + ttg.local_dealloc %qT_127 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc924) + %qT_ptrs_193 = tt.addptr %qT_ptrs_118, %qT_ptrs_98 : tensor<1x64x!tt.ptr, #blocked1>, tensor<1x64xi32, #blocked1> loc(#loc725) + %qT_ptrs_194 = tt.broadcast %qT_ptrs_193 : tensor<1x64x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> loc(#loc726) + %qT_ptrs_195 = tt.addptr %qT_ptrs_194, %qT_ptrs_72 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc726) + %do_ptrs_196 = tt.addptr %do_ptrs_122, %do_ptrs_100 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> loc(#loc727) + %do_ptrs_197 = tt.broadcast %do_ptrs_196 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> loc(#loc728) + %do_ptrs_198 = tt.addptr %do_ptrs_197, %do_ptrs_74 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc728) + %qT_199 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc886) + %lse_200 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc730) + %do_201 = ttg.local_alloc : () -> !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc887) + %Di_202 = ttg.local_alloc : () -> !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc732) + %do_ptrs_203 = arith.cmpi sgt, %hi_102, %c0_i32 : i32 loc(#loc925) + %qT_204 = arith.cmpi slt, %qT_ptrs_97, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc888) + %qT_205 = tt.broadcast %qT_204 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc886) + %qT_206 = ttg.memdesc_index %qT_199[%c0_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %do_ptrs_207 = tt.splat %do_ptrs_203 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc925) + %do_ptrs_208 = arith.andi %do_ptrs_207, %qT_205 : tensor<128x64xi1, #blocked1> loc(#loc925) + %qT_209 = ttg.async_copy_global_to_local %qT_ptrs_195, %qT_206 mask %do_ptrs_208 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %qT_210 = ttg.async_commit_group tokens %qT_209 loc(#loc886) + %lse_211 = arith.cmpi slt, %offs_m1_94, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc733) + %lse_212 = tt.addptr %lse_126, %offs_m1_94 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc734) + %lse_213 = ttg.memdesc_index %lse_200[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %do_ptrs_214 = tt.splat %do_ptrs_203 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %do_ptrs_215 = arith.andi %do_ptrs_214, %lse_211 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %lse_216 = ttg.async_copy_global_to_local %lse_212, %lse_213 mask %do_ptrs_215 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %lse_217 = ttg.async_commit_group tokens %lse_216 loc(#loc730) + %do_218 = arith.cmpi slt, %do_ptrs_99, %do : tensor<64x1xi32, #blocked> loc(#loc889) + %do_219 = tt.broadcast %do_218 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc887) + %do_220 = ttg.memdesc_index %do_201[%c0_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_ptrs_221 = tt.splat %do_ptrs_203 : i1 -> tensor<64x128xi1, #blocked> loc(#loc925) + %do_ptrs_222 = arith.andi %do_ptrs_221, %do_219 : tensor<64x128xi1, #blocked> loc(#loc925) + %do_223 = ttg.async_copy_global_to_local %do_ptrs_198, %do_220 mask %do_ptrs_222 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_224 = ttg.async_commit_group tokens %do_223 loc(#loc887) + %Di_225 = tt.addptr %Di, %offs_m1_94 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc735) + %Di_226 = ttg.memdesc_index %Di_202[%c0_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_227 = ttg.async_copy_global_to_local %Di_225, %Di_226 mask %do_ptrs_215 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_228 = ttg.async_commit_group tokens %Di_227 loc(#loc732) + %do_ptrs_229 = arith.cmpi sgt, %hi_102, %c1_i32 : i32 loc(#loc925) + %qT_ptrs_230 = tt.addptr %qT_ptrs_195, %cst_13 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc736) + %do_ptrs_231 = tt.addptr %do_ptrs_198, %cst_14 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc737) + %offs_m1_232 = arith.addi %offs_m1_94, %cst_15 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc738) + %offs_m1_233 = arith.addi %offs_m1_95, %cst_12 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc738) + %offs_m1_234 = arith.addi %offs_m1_96, %cst_16 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc738) + %qT_235 = tt.expand_dims %offs_m1_233 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc890) + %qT_236 = arith.cmpi slt, %qT_235, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc888) + %qT_237 = tt.broadcast %qT_236 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc886) + %qT_238 = ttg.memdesc_index %qT_199[%c1_i32] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %do_ptrs_239 = tt.splat %do_ptrs_229 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc925) + %do_ptrs_240 = arith.andi %do_ptrs_239, %qT_237 : tensor<128x64xi1, #blocked1> loc(#loc925) + %qT_241 = ttg.async_copy_global_to_local %qT_ptrs_230, %qT_238 mask %do_ptrs_240 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %qT_242 = ttg.async_commit_group tokens %qT_241 loc(#loc886) + %lse_243 = arith.cmpi slt, %offs_m1_232, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc733) + %lse_244 = tt.addptr %lse_126, %offs_m1_232 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc734) + %lse_245 = ttg.memdesc_index %lse_200[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %do_ptrs_246 = tt.splat %do_ptrs_229 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %do_ptrs_247 = arith.andi %do_ptrs_246, %lse_243 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %lse_248 = ttg.async_copy_global_to_local %lse_244, %lse_245 mask %do_ptrs_247 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %lse_249 = ttg.async_commit_group tokens %lse_248 loc(#loc730) + %do_250 = tt.expand_dims %offs_m1_234 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc891) + %do_251 = arith.cmpi slt, %do_250, %do : tensor<64x1xi32, #blocked> loc(#loc889) + %do_252 = tt.broadcast %do_251 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc887) + %do_253 = ttg.memdesc_index %do_201[%c1_i32] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_ptrs_254 = tt.splat %do_ptrs_229 : i1 -> tensor<64x128xi1, #blocked> loc(#loc925) + %do_ptrs_255 = arith.andi %do_ptrs_254, %do_252 : tensor<64x128xi1, #blocked> loc(#loc925) + %do_256 = ttg.async_copy_global_to_local %do_ptrs_231, %do_253 mask %do_ptrs_255 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_257 = ttg.async_commit_group tokens %do_256 loc(#loc887) + %Di_258 = tt.addptr %Di, %offs_m1_232 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc735) + %Di_259 = ttg.memdesc_index %Di_202[%c1_i32] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_260 = ttg.async_copy_global_to_local %Di_258, %Di_259 mask %do_ptrs_247 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_261 = ttg.async_commit_group tokens %Di_260 loc(#loc732) + %do_ptrs_262:20 = scf.for %do_ptrs_265 = %c0_i32 to %hi_102 step %c1_i32 iter_args(%do_ptrs_266 = %do_ptrs_191#1, %do_ptrs_267 = %do_ptrs_191#0, %qT_ptrs_268 = %qT_ptrs_230, %offs_m1_269 = %offs_m1_233, %do_ptrs_270 = %do_ptrs_231, %offs_m1_271 = %offs_m1_234, %offs_m1_272 = %offs_m1_232, %arg30 = %c1_i32, %arg31 = %c-1_i32, %arg32 = %c1_i32, %arg33 = %c-1_i32, %offs_m1_273 = %offs_m1_94, %qT_274 = %qT_210, %qT_275 = %qT_242, %lse_276 = %lse_217, %lse_277 = %lse_249, %do_278 = %do_224, %do_279 = %do_257, %Di_280 = %Di_228, %Di_281 = %Di_261) -> (tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 { + %do_ptrs_282 = arith.subi %hi_102, %c2_i32 : i32 loc(#loc925) + %do_ptrs_283 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_282 : i32 loc(#loc925) + %do_ptrs_284 = arith.subi %hi_102, %c1_i32 : i32 loc(#loc925) + %do_ptrs_285 = arith.cmpi slt, %do_ptrs_265, %do_ptrs_284 : i32 loc(#loc925) + %do_ptrs_286 = arith.addi %arg33, %c1_i32 : i32 loc(#loc925) + %do_ptrs_287 = arith.cmpi sge, %do_ptrs_286, %c2_i32 : i32 loc(#loc925) + %do_ptrs_288 = arith.select %do_ptrs_287, %c0_i32, %do_ptrs_286 : i32 loc(#loc925) + %do_ptrs_289 = arith.addi %arg31, %c1_i32 : i32 loc(#loc925) + %do_ptrs_290 = arith.cmpi sge, %do_ptrs_289, %c3_i32 : i32 loc(#loc925) + %do_ptrs_291 = arith.select %do_ptrs_290, %c0_i32, %do_ptrs_289 : i32 loc(#loc925) + %qT_292 = tt.expand_dims %offs_m1_273 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xi32, #mma1> loc(#loc890) + %qT_293 = arith.cmpi slt, %qT_292, %qT : tensor<1x64xi32, #mma1> loc(#loc888) + %qT_294 = tt.broadcast %qT_293 : tensor<1x64xi1, #mma1> -> tensor<128x64xi1, #mma1> loc(#loc886) + %qT_295 = ttg.async_wait %qT_274, %lse_276, %do_278, %Di_280 {num = 4 : i32} loc(#loc886) + %qT_296 = ttg.memdesc_index %qT_199[%do_ptrs_291] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %dk_297 = ttg.memdesc_trans %qT_296 {order = array} : !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc739) + %lse_298 = ttg.memdesc_index %lse_200[%do_ptrs_288] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %lse_299 = ttg.local_load %lse_298 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc730) + %lse_300 = arith.cmpf oeq, %lse_299, %cst_23 : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc740) + %lse_301 = arith.select %lse_300, %cst_24, %lse_299 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc741) + %qkT = ttng.warp_group_dot %k_53, %qT_296, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc742) + %qkT_302:3 = ttng.warp_group_dot_wait %qkT, %k_53, %qT_296 {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc742) + %qkT_303 = arith.mulf %qkT_302#0, %cst_8 : tensor<128x64xf32, #mma1> loc(#loc743) + %post_mod_scores = arith.select %qT_294, %qkT_303, %cst_9 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc744) + %post_mod_scores_304 = arith.mulf %post_mod_scores, %cst_10 : tensor<128x64xf32, #mma1> loc(#loc745) + %pT = tt.expand_dims %lse_301 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc746) + %pT_305 = tt.broadcast %pT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc747) + %pT_306 = arith.subf %post_mod_scores_304, %pT_305 : tensor<128x64xf32, #mma1> loc(#loc747) + %pT_307 = math.exp2 %pT_306 : tensor<128x64xf32, #mma1> loc(#loc748) + %do_308 = ttg.memdesc_index %do_201[%do_ptrs_291] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %dpT = ttg.memdesc_trans %do_308 {order = array} : !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc749) + %dv = arith.truncf %pT_307 : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc750) + %dv_309 = ttg.convert_layout %dv : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc750) + %dv_310 = ttng.warp_group_dot %dv_309, %do_308, %do_ptrs_267 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc751) + %Di_311 = ttg.memdesc_index %Di_202[%do_ptrs_288] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_312 = ttg.local_load %Di_311 token %qT_295 : !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc732) + %dpT_313 = ttng.warp_group_dot %v_58, %dpT, %cst_7 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xbf16, #shared, #smem> * !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> -> tensor<128x64xf32, #mma1> loc(#loc752) + %dpT_314:3 = ttng.warp_group_dot_wait %dpT_313, %v_58, %dpT {pendings = 0 : i32} : tensor<128x64xf32, #mma1>, !ttg.memdesc<128x128xbf16, #shared, #smem>, !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc752) + %dsT = tt.expand_dims %Di_312 {axis = 0 : i32} : tensor<64xf32, #ttg.slice<{dim = 0, parent = #mma1}>> -> tensor<1x64xf32, #mma1> loc(#loc753) + %dsT_315 = tt.broadcast %dsT : tensor<1x64xf32, #mma1> -> tensor<128x64xf32, #mma1> loc(#loc754) + %dsT_316 = arith.subf %dpT_314#0, %dsT_315 : tensor<128x64xf32, #mma1> loc(#loc754) + %dsT_317 = arith.mulf %pT_307, %dsT_316 : tensor<128x64xf32, #mma1> loc(#loc755) + %grad_scores = arith.select %qT_294, %dsT_317, %cst_7 : tensor<128x64xi1, #mma1>, tensor<128x64xf32, #mma1> loc(#loc756) + %dk_318 = arith.truncf %grad_scores : tensor<128x64xf32, #mma1> to tensor<128x64xbf16, #mma1> loc(#loc757) + %dk_319 = ttg.convert_layout %dk_318 : tensor<128x64xbf16, #mma1> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc757) + %dk_320 = ttng.warp_group_dot %dk_319, %dk_297, %do_ptrs_266 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> -> tensor<128x128xf32, #mma> loc(#loc758) + %do_ptrs_321 = arith.addi %do_ptrs_265, %c1_i32 : i32 loc(#loc925) + %cur_block_idx = arith.divsi %do_ptrs_321, %c2_i32 : i32 loc(#loc892) + %cur_block = tt.addptr %q_indices_86, %cur_block_idx : !tt.ptr, i32 loc(#loc893) + %cur_block_322 = tt.load %cur_block, %do_ptrs_285 evictionPolicy = evict_last : !tt.ptr loc(#loc894) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc895) + %next_block_323 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_90 : i32 loc(#loc896) + %next_block_324 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc897) + %do_ptrs_325 = arith.andi %do_ptrs_285, %next_block_323 : i1 loc(#loc925) + %next_block_326 = tt.load %next_block_324, %do_ptrs_325 evictionPolicy = evict_last : !tt.ptr loc(#loc898) + %needs_jump = arith.addi %do_ptrs_265, %c2_i32 : i32 loc(#loc899) + %needs_jump_327 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc900) + %needs_jump_328 = arith.cmpi eq, %needs_jump_327, %c0_i32 : i32 loc(#loc901) + %jump_to_block = arith.subi %next_block_326, %cur_block_322 : i32 loc(#loc902) + %jump_to_block_329 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc903) + %jump_to_block_330 = arith.subi %jump_to_block_329, %c64_i32 : i32 loc(#loc904) + %offset = arith.extui %needs_jump_328 : i1 to i32 loc(#loc905) + %offset_331 = arith.muli %jump_to_block_330, %offset : i32 loc(#loc905) + %offset_332 = arith.subi %c1_i32, %offset : i32 loc(#loc906) + %offset_333 = arith.muli %offset_332, %c64_i32 : i32 loc(#loc907) + %offset_334 = arith.addi %offset_331, %offset_333 : i32 loc(#loc908) + %qT_ptrs_335 = arith.muli %offset_334, %c4096_i32 : i32 loc(#loc760) + %qT_ptrs_336 = tt.splat %qT_ptrs_335 : i32 -> tensor<128x64xi32, #blocked1> loc(#loc736) + %qT_ptrs_337 = tt.addptr %qT_ptrs_268, %qT_ptrs_336 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> loc(#loc736) + %do_ptrs_338 = arith.muli %offset_334, %c128_i32 : i32 loc(#loc761) + %do_ptrs_339 = tt.splat %do_ptrs_338 : i32 -> tensor<64x128xi32, #blocked> loc(#loc737) + %do_ptrs_340 = tt.addptr %do_ptrs_270, %do_ptrs_339 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> loc(#loc737) + %offs_m1_341 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc738) + %offs_m1_342 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc738) + %offs_m1_343 = tt.splat %offset_334 : i32 -> tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc738) + %offs_m1_344 = arith.addi %offs_m1_272, %offs_m1_341 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc738) + %offs_m1_345 = arith.addi %offs_m1_269, %offs_m1_342 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc738) + %offs_m1_346 = arith.addi %offs_m1_271, %offs_m1_343 : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc738) + %do_ptrs_347 = arith.addi %arg32, %c1_i32 : i32 loc(#loc925) + %do_ptrs_348 = arith.cmpi sge, %do_ptrs_347, %c2_i32 : i32 loc(#loc925) + %do_ptrs_349 = arith.select %do_ptrs_348, %c0_i32, %do_ptrs_347 : i32 loc(#loc925) + %do_ptrs_350 = arith.addi %arg30, %c1_i32 : i32 loc(#loc925) + %do_ptrs_351 = arith.cmpi sge, %do_ptrs_350, %c3_i32 : i32 loc(#loc925) + %do_ptrs_352 = arith.select %do_ptrs_351, %c0_i32, %do_ptrs_350 : i32 loc(#loc925) + %qT_353 = tt.expand_dims %offs_m1_345 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> loc(#loc890) + %qT_354 = arith.cmpi slt, %qT_353, %qT_79 : tensor<1x64xi32, #blocked1> loc(#loc888) + %qT_355 = tt.broadcast %qT_354 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1> loc(#loc886) + %qT_356 = ttg.memdesc_index %qT_199[%do_ptrs_352] : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %do_ptrs_357 = tt.splat %do_ptrs_283 : i1 -> tensor<128x64xi1, #blocked1> loc(#loc925) + %do_ptrs_358 = arith.andi %do_ptrs_357, %qT_355 : tensor<128x64xi1, #blocked1> loc(#loc925) + %qT_359 = ttg.async_copy_global_to_local %qT_ptrs_337, %qT_356 mask %do_ptrs_358 other %cst_3 : tensor<128x64x!tt.ptr, #blocked1> -> <128x64xbf16, #shared1, #smem, mutable, 3x128x64> loc(#loc886) + %qT_360 = ttg.async_commit_group tokens %qT_359 loc(#loc886) + %lse_361 = arith.cmpi slt, %offs_m1_344, %lse : tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc733) + %lse_362 = tt.addptr %lse_126, %offs_m1_344 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc734) + %lse_363 = ttg.memdesc_index %lse_200[%do_ptrs_349] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %do_ptrs_364 = tt.splat %do_ptrs_283 : i1 -> tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %do_ptrs_365 = arith.andi %do_ptrs_364, %lse_361 : tensor<64xi1, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc925) + %lse_366 = ttg.async_copy_global_to_local %lse_362, %lse_363 mask %do_ptrs_365 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc730) + %lse_367 = ttg.async_commit_group tokens %lse_366 loc(#loc730) + %do_368 = tt.expand_dims %offs_m1_346 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> loc(#loc891) + %do_369 = arith.cmpi slt, %do_368, %do : tensor<64x1xi32, #blocked> loc(#loc889) + %do_370 = tt.broadcast %do_369 : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> loc(#loc887) + %do_371 = ttg.memdesc_index %do_201[%do_ptrs_352] : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> -> !ttg.memdesc<64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_ptrs_372 = tt.splat %do_ptrs_283 : i1 -> tensor<64x128xi1, #blocked> loc(#loc925) + %do_ptrs_373 = arith.andi %do_ptrs_372, %do_370 : tensor<64x128xi1, #blocked> loc(#loc925) + %do_374 = ttg.async_copy_global_to_local %do_ptrs_340, %do_371 mask %do_ptrs_373 other %cst_4 : tensor<64x128x!tt.ptr, #blocked> -> <64x128xbf16, #shared, #smem, mutable, 3x64x128> loc(#loc887) + %do_375 = ttg.async_commit_group tokens %do_374 loc(#loc887) + %Di_376 = tt.addptr %Di, %offs_m1_344 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>> loc(#loc735) + %Di_377 = ttg.memdesc_index %Di_202[%do_ptrs_349] : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> -> !ttg.memdesc<64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_378 = ttg.async_copy_global_to_local %Di_376, %Di_377 mask %do_ptrs_365 : tensor<64x!tt.ptr, #ttg.slice<{dim = 0, parent = #mma1}>> -> <64xf32, #shared2, #smem, mutable, 2x64> loc(#loc732) + %Di_379 = ttg.async_commit_group tokens %Di_378 loc(#loc732) + scf.yield %dk_320, %dv_310, %qT_ptrs_337, %offs_m1_345, %do_ptrs_340, %offs_m1_346, %offs_m1_344, %do_ptrs_352, %do_ptrs_291, %do_ptrs_349, %do_ptrs_288, %offs_m1_272, %qT_275, %qT_360, %lse_277, %lse_367, %do_279, %do_375, %Di_281, %Di_379 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<64x128x!tt.ptr, #blocked>, tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, i32, i32, i32, i32, tensor<64xi32, #ttg.slice<{dim = 0, parent = #mma1}>>, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc925) + } loc(#loc925) + %do_ptrs_263:2 = ttng.warp_group_dot_wait %do_ptrs_262#1, %do_ptrs_262#0 {pendings = 0 : i32} : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc925) + %do_ptrs_264 = ttg.async_wait {num = 0 : i32} loc(#loc925) + ttg.local_dealloc %Di_202 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc925) + ttg.local_dealloc %do_201 : !ttg.memdesc<3x64x128xbf16, #shared, #smem, mutable> loc(#loc925) + ttg.local_dealloc %lse_200 : !ttg.memdesc<2x64xf32, #shared2, #smem, mutable> loc(#loc925) + ttg.local_dealloc %qT_199 : !ttg.memdesc<3x128x64xbf16, #shared1, #smem, mutable> loc(#loc925) + scf.yield %do_ptrs_263#0, %do_ptrs_263#1 : tensor<128x128xf32, #mma>, tensor<128x128xf32, #mma> loc(#loc289) + } loc(#loc677) + %dv_ptrs = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc583) + %dv_ptrs_103 = tt.addptr %dv_ptrs, %ptr_41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> loc(#loc583) + %dv_ptrs_104 = tt.broadcast %dv_ptrs_103 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc584) + %dv_ptrs_105 = tt.addptr %dv_ptrs_104, %ptr_47 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc584) + %11 = arith.cmpi slt, %ptr_45, %cst_0 : tensor<1x128xi32, #blocked> loc(#loc292) + %12 = tt.broadcast %11 : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc293) + %13 = arith.andi %k_51, %12 : tensor<128x128xi1, #blocked> loc(#loc293) + %14 = arith.truncf %dk#0 : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc294) + %15 = ttg.convert_layout %14 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc294) + tt.store %dv_ptrs_105, %15, %13 : tensor<128x128x!tt.ptr, #blocked> loc(#loc294) + %dk_106 = arith.mulf %dk#1, %cst_5 : tensor<128x128xf32, #mma> loc(#loc585) + %16 = tt.splat %k_adj : i32 -> tensor<1x128xi32, #blocked> loc(#loc296) + %17 = arith.addi %ptr_45, %16 : tensor<1x128xi32, #blocked> loc(#loc296) + %18 = tt.broadcast %17 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc297) + %19 = tt.broadcast %ptr_41 : tensor<128x1xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc297) + %20 = arith.addi %18, %19 : tensor<128x128xi32, #blocked> loc(#loc297) + %21 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked> loc(#loc298) + %22 = tt.addptr %21, %20 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc298) + %23 = arith.truncf %dk_106 : tensor<128x128xf32, #mma> to tensor<128x128xbf16, #mma> loc(#loc299) + %24 = ttg.convert_layout %23 : tensor<128x128xbf16, #mma> -> tensor<128x128xbf16, #blocked> loc(#loc299) + tt.store %22, %24, %k_51 : tensor<128x128x!tt.ptr, #blocked> loc(#loc299) + } loc(#loc28) + tt.return loc(#loc300) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":94:54) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:74) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:66) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:100) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:91) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:82) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:59) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:111) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":100:58) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":111:24) +#loc12 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":112:36) +#loc14 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc15 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":113:34) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":115:27) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":116:28) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:25) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:59) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:50) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:37) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:61) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":131:9) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":132:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":133:10) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":136:26) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:14) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:7) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":140:24) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:29) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:54) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:44) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":145:35) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:30) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:52) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:40) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:63) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:32) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:55) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:42) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:66) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:30) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:35) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:46) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:56) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":163:17) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":164:19) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":167:19) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":168:21) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":169:25) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":174:36) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":175:29) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:27) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":178:107) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:38) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:20) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:56) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:49) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:52) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:23) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":179:111) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:58) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:34) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:25) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:33) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:26) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:30) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:50) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":191:18) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":195:30) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:27) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:41) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:53) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:39) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:42) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:29) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:26) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":207:12) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:37) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:18) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:56) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:49) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:18) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:49) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:43) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:90) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:101) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:63) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:52) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":458:105) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":405:12) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":762:21) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":467:46) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":481:22) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":483:23) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":484:22) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":485:23) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":487:22) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:70) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:79) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:91) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:99) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:102) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:119) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":495:25) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:39) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:22) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:19) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:23) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":510:104) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":397:28) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:19) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":415:19) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":417:19) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:41) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":459:19) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:30) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":461:14) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":464:46) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":476:79) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":486:22) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":488:24) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":489:23) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:70) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:79) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:91) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:99) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:102) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:119) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":496:24) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":497:23) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":498:23) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":503:69) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":506:27) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:21) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":512:20) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:14) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":520:71) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":531:43) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":533:15) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:21) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":752:33) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":411:64) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:38) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:24) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:109) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:113) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:55) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:25) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:30) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:35) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:60) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:34) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:48) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:63) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:29) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:47) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:61) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:42) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:28) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":214:39) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:31) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:45) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:62) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:43) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":218:33) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":226:16) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:24) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:56) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":232:14) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:87) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:69) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:30) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":252:25) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":253:29) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":256:107) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":257:107) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:32) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:56) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:59) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:34) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":286:32) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:30) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:43) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:55) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:42) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:45) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:32) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:26) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":298:16) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:37) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:56) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:49) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:27) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:38) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:51) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:42) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:87) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:98) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:61) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":651:105) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":600:12) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:52) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":665:46) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":681:25) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":684:24) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":685:24) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":686:25) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":687:24) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:70) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:79) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:91) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:99) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:102) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:119) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":693:25) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":705:99) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":306:41) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:34) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:47) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:64) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:46) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":310:36) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":318:20) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":658:20) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":262:30) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:51) +#loc228 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:34) +#loc229 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:44) +#loc230 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:67) +#loc231 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:36) +#loc232 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:46) +#loc233 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:70) +#loc234 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:39) +#loc235 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:50) +#loc236 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:60) +#loc237 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":271:21) +#loc238 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":272:23) +#loc239 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":275:25) +#loc240 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":276:29) +#loc241 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:18) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:19) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:28) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:29) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:22) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:21) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":592:28) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:19) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:19) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":610:19) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:41) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:52) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:26) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:46) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":660:15) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":662:46) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":674:78) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":679:24) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":682:24) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":683:25) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:70) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:79) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:91) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:99) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:102) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:119) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":694:24) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":695:24) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":696:24) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":700:69) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":703:27) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:44) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:40) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:22) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:29) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:24) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:43) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:20) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:25) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:22) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:16) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":723:70) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":737:45) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:24) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:43) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":605:62) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:28) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:28) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":303:12) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:23) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:55) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:71) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:61) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:30) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":334:14) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:55) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:69) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:29) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:99) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:4) +#loc320 = loc("pid"(#loc11)) +#loc321 = loc("NUM_KV_BLOCKS"(#loc13)) +#loc322 = loc("NUM_Q_BLOCKS"(#loc15)) +#loc323 = loc("off_zq"(#loc16)) +#loc324 = loc("off_hkv"(#loc17)) +#loc325 = loc("k_adj"(#loc18)) +#loc326 = loc("k_adj"(#loc19)) +#loc327 = loc("dv_adj"(#loc20)) +#loc328 = loc("dv_adj"(#loc21)) +#loc329 = loc("dv_adj"(#loc22)) +#loc330 = loc("K"(#loc23)) +#loc331 = loc("V"(#loc24)) +#loc332 = loc("DV"(#loc25)) +#loc333 = loc("offs_k"(#loc26)) +#loc334 = loc("off_pid"(#loc29)) +#loc335 = loc("off_hq2"(#loc30)) +#loc336 = loc("off_hq2"(#loc31)) +#loc337 = loc("off_hq2"(#loc32)) +#loc338 = loc("start_m2_block"(#loc33)) +#loc339 = loc("q_adj2"(#loc34)) +#loc340 = loc("q_adj2"(#loc35)) +#loc341 = loc("q_adj2"(#loc36)) +#loc342 = loc("q_adj2"(#loc37)) +#loc343 = loc("do_adj2"(#loc38)) +#loc344 = loc("do_adj2"(#loc39)) +#loc345 = loc("do_adj2"(#loc40)) +#loc346 = loc("do_adj2"(#loc41)) +#loc347 = loc("off_chz2"(#loc42)) +#loc348 = loc("off_chz2"(#loc43)) +#loc349 = loc("off_chz2"(#loc44)) +#loc350 = loc("off_chz2"(#loc45)) +#loc351 = loc("Q2"(#loc46)) +#loc352 = loc("DO2"(#loc47)) +#loc353 = loc("DQ2"(#loc48)) +#loc354 = loc("LSE2"(#loc49)) +#loc355 = loc("DELTA2"(#loc50)) +#loc356 = loc("start_m2"(#loc51)) +#loc357 = loc("offs_m2"(#loc52)) +#loc358 = loc("ptr"(#loc53)) +#loc359 = loc("q"(#loc54)) +#loc360 = loc("ptr"(#loc55)) +#loc361 = loc("ptr"(#loc56)) +#loc362 = loc("ptr"(#loc57)) +#loc363 = loc("ptr"(#loc58)) +#loc364 = loc("do"(#loc61)) +#loc365 = loc("Di"(#loc62)) +#loc366 = loc("Di"(#loc63)) +#loc367 = loc("Di"(#loc64)) +#loc368 = loc("lse"(#loc65)) +#loc369 = loc("lse"(#loc66)) +#loc370 = loc("lse"(#loc67)) +#loc371 = loc("lse"(#loc68)) +#loc372 = loc("lse"(#loc69)) +#loc373 = loc("kv_indices"(#loc70)) +#loc374 = loc("kv_start"(#loc71)) +#loc375 = loc("kv_start"(#loc72)) +#loc376 = loc("sparse_kv_num_blocks"(#loc73)) +#loc377 = loc("sparse_kv_num_blocks"(#loc74)) +#loc378 = loc("offs_n2"(#loc75)) +#loc379 = loc("offs_n2"(#loc76)) +#loc380 = loc("kT_ptrs"(#loc77)) +#loc381 = loc("dq"(#loc78)) +#loc382 = loc("kT_ptrs"(#loc79)) +#loc383 = loc("kT_ptrs"(#loc80)) +#loc384 = loc("kT_ptrs"(#loc81)) +#loc385 = loc("kT_ptrs"(#loc82)) +#loc386 = loc("vT_ptrs"(#loc83)) +#loc387 = loc("vT_ptrs"(#loc84)) +#loc388 = loc("hi"(#loc85)) +#loc389 = loc("hi"(#loc86)) +#loc390 = loc("hi"(#loc87)) +#loc391 = loc("hi"(#loc88)) +#loc392 = loc("kT"(#loc90)) +#loc393 = loc("dq"(#loc91)) +#loc394 = loc("m"(#loc93)) +#loc395 = loc("tmp3"(#loc94)) +#loc396 = loc("tmp5"(#loc95)) +#loc397 = loc("tmp6"(#loc96)) +#loc398 = loc("tmp7"(#loc97)) +#loc399 = loc("tmp9"(#loc98)) +#loc400 = loc("tmp14"(#loc99)) +#loc401 = loc("tmp14"(#loc100)) +#loc402 = loc("tmp14"(#loc101)) +#loc403 = loc("tmp14"(#loc102)) +#loc404 = loc("tmp14"(#loc103)) +#loc405 = loc("tmp14"(#loc104)) +#loc406 = loc("tmp17"(#loc105)) +#loc407 = loc("p"(#loc106)) +#loc408 = loc("ds"(#loc107)) +#loc409 = loc("ds"(#loc108)) +#loc410 = loc("vT"(#loc110)) +#loc411 = loc("dq"(#loc111)) +#loc412 = loc("kT_ptrs"(#loc112)) +#loc413 = loc("vT_ptrs"(#loc113)) +#loc414 = loc("offs_n2"(#loc114)) +#loc415 = loc("qk"(#loc116)) +#loc416 = loc("dq"(#loc117)) +#loc417 = loc("qk"(#loc118)) +#loc418 = loc("n"(#loc119)) +#loc419 = loc("post_mod_scores"(#loc120)) +#loc420 = loc("tmp8"(#loc121)) +#loc421 = loc("tmp10"(#loc122)) +#loc422 = loc("tmp11"(#loc123)) +#loc423 = loc("tmp16"(#loc124)) +#loc424 = loc("tmp16"(#loc125)) +#loc425 = loc("tmp16"(#loc126)) +#loc426 = loc("tmp16"(#loc127)) +#loc427 = loc("tmp16"(#loc128)) +#loc428 = loc("tmp16"(#loc129)) +#loc429 = loc("tmp18"(#loc130)) +#loc430 = loc("tmp19"(#loc131)) +#loc431 = loc("tmp20"(#loc132)) +#loc432 = loc("post_mod_scores"(#loc133)) +#loc433 = loc("post_mod_scores"(#loc134)) +#loc434 = loc("p"(#loc135)) +#loc435 = loc("dp"(#loc136)) +#loc436 = loc("ds"(#loc137)) +#loc437 = loc("grad_scores"(#loc138)) +#loc438 = loc("ds"(#loc139)) +#loc439 = loc("ds"(#loc140)) +#loc440 = loc("dq"(#loc141)) +#loc441 = loc("cur_block_idx"(#loc142)) +#loc442 = loc("offset"(#loc143)) +#loc443 = loc("cur_block"(#loc144)) +#loc444 = loc("cur_block"(#loc145)) +#loc445 = loc("next_block"(#loc146)) +#loc446 = loc("next_block"(#loc147)) +#loc447 = loc("next_block"(#loc148)) +#loc448 = loc("next_block"(#loc149)) +#loc449 = loc("needs_jump"(#loc150)) +#loc450 = loc("needs_jump"(#loc151)) +#loc451 = loc("needs_jump"(#loc152)) +#loc452 = loc("jump_to_block"(#loc153)) +#loc453 = loc("jump_to_block"(#loc154)) +#loc454 = loc("jump_to_block"(#loc155)) +#loc455 = loc("offset"(#loc156)) +#loc456 = loc("offset"(#loc157)) +#loc457 = loc("offset"(#loc158)) +#loc458 = loc("offset"(#loc159)) +#loc459 = loc("kT_ptrs"(#loc160)) +#loc460 = loc("kv_indices"(#loc161)) +#loc461 = loc("kv_start"(#loc162)) +#loc462 = loc("kv_start"(#loc163)) +#loc463 = loc("sparse_kv_num_blocks"(#loc164)) +#loc464 = loc("sparse_kv_num_blocks"(#loc165)) +#loc465 = loc("offs_n2"(#loc166)) +#loc466 = loc("dq"(#loc167)) +#loc467 = loc("dq_ptrs"(#loc168)) +#loc468 = loc("dq_ptrs"(#loc169)) +#loc469 = loc("dq"(#loc170)) +#loc470 = loc("start_n1"(#loc174)) +#loc471 = loc("offs_n1"(#loc175)) +#loc472 = loc("k"(#loc176)) +#loc473 = loc("v"(#loc177)) +#loc474 = loc("off_hq1"(#loc178)) +#loc475 = loc("q_adj1"(#loc179)) +#loc476 = loc("do_adj1"(#loc180)) +#loc477 = loc("off_chz1"(#loc181)) +#loc478 = loc("q_indices"(#loc182)) +#loc479 = loc("q_start"(#loc183)) +#loc480 = loc("q_start"(#loc184)) +#loc481 = loc("sparse_q_num_blocks"(#loc185)) +#loc482 = loc("sparse_q_num_blocks"(#loc186)) +#loc483 = loc("offs_m1"(#loc187)) +#loc484 = loc("offs_m1"(#loc188)) +#loc485 = loc("qT_ptrs"(#loc189)) +#loc486 = loc("qT_ptrs"(#loc191)) +#loc487 = loc("qT_ptrs"(#loc192)) +#loc488 = loc("qT_ptrs"(#loc193)) +#loc489 = loc("do_ptrs"(#loc194)) +#loc490 = loc("do_ptrs"(#loc195)) +#loc491 = loc("do_ptrs"(#loc196)) +#loc492 = loc("hi"(#loc197)) +#loc493 = loc("hi"(#loc198)) +#loc494 = loc("hi"(#loc199)) +#loc495 = loc("hi"(#loc200)) +#loc496 = loc("qT"(#loc201)) +#loc497 = loc(callsite(#loc202 at #loc190)) +#loc498 = loc("lse"(#loc203)) +#loc499 = loc("n"(#loc204)) +#loc500 = loc("tmp27"(#loc205)) +#loc501 = loc("tmp30"(#loc206)) +#loc502 = loc("tmp31"(#loc207)) +#loc503 = loc("tmp32"(#loc208)) +#loc504 = loc("tmp33"(#loc209)) +#loc505 = loc("tmp38"(#loc210)) +#loc506 = loc("tmp38"(#loc211)) +#loc507 = loc("tmp38"(#loc212)) +#loc508 = loc("tmp38"(#loc213)) +#loc509 = loc("tmp38"(#loc214)) +#loc510 = loc("tmp38"(#loc215)) +#loc511 = loc("tmp39"(#loc216)) +#loc512 = loc("do"(#loc217)) +#loc513 = loc("q_indices"(#loc218)) +#loc514 = loc("q_start"(#loc219)) +#loc515 = loc("q_start"(#loc220)) +#loc516 = loc("sparse_q_num_blocks"(#loc221)) +#loc517 = loc("sparse_q_num_blocks"(#loc222)) +#loc518 = loc("offs_m1"(#loc223)) +#loc519 = loc("qkT"(#loc225)) +#loc520 = loc("dv"(#loc226)) +#loc521 = loc("off_hq1"(#loc227)) +#loc522 = loc("q_adj1"(#loc228)) +#loc523 = loc("q_adj1"(#loc229)) +#loc524 = loc("q_adj1"(#loc230)) +#loc525 = loc("do_adj1"(#loc231)) +#loc526 = loc("do_adj1"(#loc232)) +#loc527 = loc("do_adj1"(#loc233)) +#loc528 = loc("off_chz1"(#loc234)) +#loc529 = loc("off_chz1"(#loc235)) +#loc530 = loc("off_chz1"(#loc236)) +#loc531 = loc("Q1"(#loc237)) +#loc532 = loc("DO1"(#loc238)) +#loc533 = loc("LSE1"(#loc239)) +#loc534 = loc("DELTA1"(#loc240)) +#loc535 = loc("qT_ptrs"(#loc241)) +#loc536 = loc("do_ptrs"(#loc242)) +#loc537 = loc("lse"(#loc243)) +#loc538 = loc("Di"(#loc244)) +#loc539 = loc("lse"(#loc245)) +#loc540 = loc("Di"(#loc246)) +#loc541 = loc("dk"(#loc247)) +#loc542 = loc("qT_ptrs"(#loc248)) +#loc543 = loc("do_ptrs"(#loc249)) +#loc544 = loc("offs_m1"(#loc250)) +#loc545 = loc("dk"(#loc252)) +#loc546 = loc("lse"(#loc253)) +#loc547 = loc("lse"(#loc254)) +#loc548 = loc("qkT"(#loc255)) +#loc549 = loc("m"(#loc256)) +#loc550 = loc("post_mod_scores"(#loc257)) +#loc551 = loc("tmp25"(#loc258)) +#loc552 = loc("tmp28"(#loc259)) +#loc553 = loc("tmp29"(#loc260)) +#loc554 = loc("tmp36"(#loc261)) +#loc555 = loc("tmp36"(#loc262)) +#loc556 = loc("tmp36"(#loc263)) +#loc557 = loc("tmp36"(#loc264)) +#loc558 = loc("tmp36"(#loc265)) +#loc559 = loc("tmp36"(#loc266)) +#loc560 = loc("tmp40"(#loc267)) +#loc561 = loc("tmp41"(#loc268)) +#loc562 = loc("tmp42"(#loc269)) +#loc563 = loc("post_mod_scores"(#loc270)) +#loc564 = loc("post_mod_scores"(#loc271)) +#loc565 = loc("pT"(#loc272)) +#loc566 = loc("pT"(#loc273)) +#loc567 = loc("pT"(#loc274)) +#loc568 = loc("dpT"(#loc275)) +#loc569 = loc("dv"(#loc276)) +#loc570 = loc("dv"(#loc277)) +#loc571 = loc("dpT"(#loc278)) +#loc572 = loc("dsT"(#loc279)) +#loc573 = loc("dsT"(#loc280)) +#loc574 = loc("dsT"(#loc281)) +#loc575 = loc("grad_scores"(#loc282)) +#loc576 = loc("dsT"(#loc283)) +#loc577 = loc("dk"(#loc284)) +#loc578 = loc("dk"(#loc285)) +#loc579 = loc("offset"(#loc286)) +#loc580 = loc("qT_ptrs"(#loc287)) +#loc581 = loc("do_ptrs"(#loc288)) +#loc582 = loc(callsite(#loc202 at #loc224)) +#loc583 = loc("dv_ptrs"(#loc290)) +#loc584 = loc("dv_ptrs"(#loc291)) +#loc585 = loc("dk"(#loc295)) +#loc586 = loc(callsite(#loc12 at #loc321)) +#loc587 = loc(callsite(#loc14 at #loc321)) +#loc588 = loc(callsite(#loc12 at #loc322)) +#loc589 = loc(callsite(#loc14 at #loc322)) +#loc590 = loc(callsite(#loc358 at #loc359)) +#loc591 = loc(callsite(#loc360 at #loc359)) +#loc592 = loc(callsite(#loc361 at #loc359)) +#loc593 = loc(callsite(#loc362 at #loc359)) +#loc594 = loc(callsite(#loc363 at #loc359)) +#loc595 = loc(callsite(#loc59 at #loc359)) +#loc596 = loc(callsite(#loc60 at #loc359)) +#loc597 = loc(callsite(#loc360 at #loc364)) +#loc598 = loc(callsite(#loc361 at #loc364)) +#loc599 = loc(callsite(#loc363 at #loc364)) +#loc600 = loc(callsite(#loc60 at #loc364)) +#loc601 = loc(callsite(#loc380 at #loc381)) +#loc602 = loc(callsite(#loc382 at #loc381)) +#loc603 = loc(callsite(#loc383 at #loc381)) +#loc604 = loc(callsite(#loc384 at #loc381)) +#loc605 = loc(callsite(#loc385 at #loc381)) +#loc606 = loc(callsite(#loc386 at #loc381)) +#loc607 = loc(callsite(#loc387 at #loc381)) +#loc608 = loc(callsite(#loc388 at #loc381)) +#loc609 = loc(callsite(#loc389 at #loc381)) +#loc610 = loc(callsite(#loc390 at #loc381)) +#loc611 = loc(callsite(#loc391 at #loc381)) +#loc612 = loc(callsite(#loc393 at #loc381)) +#loc613 = loc("offs_n2"(#loc411)) +#loc614 = loc(callsite(#loc412 at #loc381)) +#loc615 = loc(callsite(#loc413 at #loc381)) +#loc616 = loc(callsite(#loc414 at #loc381)) +#loc617 = loc(callsite(#loc442 at #loc381)) +#loc618 = loc(callsite(#loc459 at #loc381)) +#loc619 = loc(callsite(#loc380 at #loc466)) +#loc620 = loc(callsite(#loc382 at #loc466)) +#loc621 = loc(callsite(#loc383 at #loc466)) +#loc622 = loc(callsite(#loc385 at #loc466)) +#loc623 = loc(callsite(#loc386 at #loc466)) +#loc624 = loc(callsite(#loc387 at #loc466)) +#loc625 = loc(callsite(#loc388 at #loc466)) +#loc626 = loc(callsite(#loc391 at #loc466)) +#loc627 = loc(callsite(#loc393 at #loc466)) +#loc628 = loc(callsite(#loc412 at #loc466)) +#loc629 = loc(callsite(#loc413 at #loc466)) +#loc630 = loc(callsite(#loc414 at #loc466)) +#loc631 = loc(callsite(#loc442 at #loc466)) +#loc632 = loc(callsite(#loc459 at #loc466)) +#loc633 = loc(callsite(#loc358 at #loc472)) +#loc634 = loc(callsite(#loc360 at #loc472)) +#loc635 = loc(callsite(#loc361 at #loc472)) +#loc636 = loc(callsite(#loc362 at #loc472)) +#loc637 = loc(callsite(#loc363 at #loc472)) +#loc638 = loc(callsite(#loc59 at #loc472)) +#loc639 = loc(callsite(#loc60 at #loc472)) +#loc640 = loc(callsite(#loc361 at #loc473)) +#loc641 = loc(callsite(#loc363 at #loc473)) +#loc642 = loc(callsite(#loc60 at #loc473)) +#loc643 = loc(callsite(#loc485 at #loc190)) +#loc644 = loc(callsite(#loc486 at #loc190)) +#loc645 = loc(callsite(#loc487 at #loc190)) +#loc646 = loc(callsite(#loc488 at #loc190)) +#loc647 = loc(callsite(#loc489 at #loc190)) +#loc648 = loc(callsite(#loc490 at #loc190)) +#loc649 = loc(callsite(#loc491 at #loc190)) +#loc650 = loc(callsite(#loc492 at #loc190)) +#loc651 = loc(callsite(#loc493 at #loc190)) +#loc652 = loc(callsite(#loc494 at #loc190)) +#loc653 = loc(callsite(#loc495 at #loc190)) +#loc654 = loc(callsite(#loc496 at #loc497)) +#loc655 = loc(callsite(#loc498 at #loc497)) +#loc656 = loc(callsite(#loc499 at #loc497)) +#loc657 = loc(callsite(#loc500 at #loc497)) +#loc658 = loc(callsite(#loc501 at #loc497)) +#loc659 = loc(callsite(#loc502 at #loc497)) +#loc660 = loc(callsite(#loc503 at #loc497)) +#loc661 = loc(callsite(#loc504 at #loc497)) +#loc662 = loc(callsite(#loc505 at #loc497)) +#loc663 = loc(callsite(#loc506 at #loc497)) +#loc664 = loc(callsite(#loc507 at #loc497)) +#loc665 = loc(callsite(#loc508 at #loc497)) +#loc666 = loc(callsite(#loc509 at #loc497)) +#loc667 = loc(callsite(#loc510 at #loc497)) +#loc668 = loc(callsite(#loc511 at #loc497)) +#loc669 = loc(callsite(#loc512 at #loc497)) +#loc670 = loc(callsite(#loc485 at #loc224)) +#loc671 = loc(callsite(#loc486 at #loc224)) +#loc672 = loc(callsite(#loc489 at #loc224)) +#loc673 = loc(callsite(#loc490 at #loc224)) +#loc674 = loc(callsite(#loc492 at #loc224)) +#loc675 = loc(callsite(#loc495 at #loc224)) +#loc676 = loc(callsite(#loc519 at #loc497)) +#loc677 = loc("dk"(#loc520)) +#loc678 = loc(callsite(#loc535 at #loc190)) +#loc679 = loc(callsite(#loc536 at #loc190)) +#loc680 = loc(callsite(#loc537 at #loc497)) +#loc681 = loc(callsite(#loc538 at #loc497)) +#loc682 = loc(callsite(#loc539 at #loc497)) +#loc683 = loc(callsite(#loc540 at #loc497)) +#loc684 = loc("dv"(#loc541)) +#loc685 = loc(callsite(#loc542 at #loc190)) +#loc686 = loc(callsite(#loc543 at #loc190)) +#loc687 = loc(callsite(#loc544 at #loc190)) +#loc688 = loc(callsite(#loc545 at #loc497)) +#loc689 = loc(callsite(#loc546 at #loc497)) +#loc690 = loc(callsite(#loc547 at #loc497)) +#loc691 = loc(callsite(#loc548 at #loc497)) +#loc692 = loc(callsite(#loc549 at #loc497)) +#loc693 = loc(callsite(#loc550 at #loc497)) +#loc694 = loc(callsite(#loc551 at #loc497)) +#loc695 = loc(callsite(#loc552 at #loc497)) +#loc696 = loc(callsite(#loc553 at #loc497)) +#loc697 = loc(callsite(#loc554 at #loc497)) +#loc698 = loc(callsite(#loc555 at #loc497)) +#loc699 = loc(callsite(#loc556 at #loc497)) +#loc700 = loc(callsite(#loc557 at #loc497)) +#loc701 = loc(callsite(#loc558 at #loc497)) +#loc702 = loc(callsite(#loc559 at #loc497)) +#loc703 = loc(callsite(#loc560 at #loc497)) +#loc704 = loc(callsite(#loc561 at #loc497)) +#loc705 = loc(callsite(#loc562 at #loc497)) +#loc706 = loc(callsite(#loc563 at #loc497)) +#loc707 = loc(callsite(#loc564 at #loc497)) +#loc708 = loc(callsite(#loc565 at #loc497)) +#loc709 = loc(callsite(#loc566 at #loc497)) +#loc710 = loc(callsite(#loc567 at #loc497)) +#loc711 = loc(callsite(#loc568 at #loc497)) +#loc712 = loc(callsite(#loc569 at #loc497)) +#loc713 = loc(callsite(#loc570 at #loc497)) +#loc714 = loc(callsite(#loc571 at #loc497)) +#loc715 = loc(callsite(#loc572 at #loc497)) +#loc716 = loc(callsite(#loc573 at #loc497)) +#loc717 = loc(callsite(#loc574 at #loc497)) +#loc718 = loc(callsite(#loc575 at #loc497)) +#loc719 = loc(callsite(#loc576 at #loc497)) +#loc720 = loc(callsite(#loc577 at #loc497)) +#loc721 = loc(callsite(#loc578 at #loc497)) +#loc722 = loc(callsite(#loc579 at #loc190)) +#loc723 = loc(callsite(#loc580 at #loc190)) +#loc724 = loc(callsite(#loc581 at #loc190)) +#loc725 = loc(callsite(#loc535 at #loc224)) +#loc726 = loc(callsite(#loc488 at #loc224)) +#loc727 = loc(callsite(#loc536 at #loc224)) +#loc728 = loc(callsite(#loc491 at #loc224)) +#loc729 = loc(callsite(#loc496 at #loc582)) +#loc730 = loc(callsite(#loc539 at #loc582)) +#loc731 = loc(callsite(#loc512 at #loc582)) +#loc732 = loc(callsite(#loc540 at #loc582)) +#loc733 = loc(callsite(#loc498 at #loc582)) +#loc734 = loc(callsite(#loc537 at #loc582)) +#loc735 = loc(callsite(#loc538 at #loc582)) +#loc736 = loc(callsite(#loc542 at #loc224)) +#loc737 = loc(callsite(#loc543 at #loc224)) +#loc738 = loc(callsite(#loc544 at #loc224)) +#loc739 = loc(callsite(#loc545 at #loc582)) +#loc740 = loc(callsite(#loc546 at #loc582)) +#loc741 = loc(callsite(#loc547 at #loc582)) +#loc742 = loc(callsite(#loc519 at #loc582)) +#loc743 = loc(callsite(#loc548 at #loc582)) +#loc744 = loc(callsite(#loc550 at #loc582)) +#loc745 = loc(callsite(#loc564 at #loc582)) +#loc746 = loc(callsite(#loc565 at #loc582)) +#loc747 = loc(callsite(#loc566 at #loc582)) +#loc748 = loc(callsite(#loc567 at #loc582)) +#loc749 = loc(callsite(#loc568 at #loc582)) +#loc750 = loc(callsite(#loc569 at #loc582)) +#loc751 = loc(callsite(#loc570 at #loc582)) +#loc752 = loc(callsite(#loc571 at #loc582)) +#loc753 = loc(callsite(#loc572 at #loc582)) +#loc754 = loc(callsite(#loc573 at #loc582)) +#loc755 = loc(callsite(#loc574 at #loc582)) +#loc756 = loc(callsite(#loc575 at #loc582)) +#loc757 = loc(callsite(#loc577 at #loc582)) +#loc758 = loc(callsite(#loc578 at #loc582)) +#loc759 = loc(callsite(#loc579 at #loc224)) +#loc760 = loc(callsite(#loc580 at #loc224)) +#loc761 = loc(callsite(#loc581 at #loc224)) +#loc762 = loc(callsite(#loc12 at #loc609)) +#loc763 = loc(callsite(#loc14 at #loc609)) +#loc764 = loc(callsite(#loc392 at #loc612)) +#loc765 = loc(callsite(#loc394 at #loc612)) +#loc766 = loc(callsite(#loc395 at #loc612)) +#loc767 = loc(callsite(#loc396 at #loc612)) +#loc768 = loc(callsite(#loc397 at #loc612)) +#loc769 = loc(callsite(#loc398 at #loc612)) +#loc770 = loc(callsite(#loc399 at #loc612)) +#loc771 = loc(callsite(#loc400 at #loc612)) +#loc772 = loc(callsite(#loc401 at #loc612)) +#loc773 = loc(callsite(#loc402 at #loc612)) +#loc774 = loc(callsite(#loc403 at #loc612)) +#loc775 = loc(callsite(#loc404 at #loc612)) +#loc776 = loc(callsite(#loc405 at #loc612)) +#loc777 = loc(callsite(#loc406 at #loc612)) +#loc778 = loc(callsite(#loc407 at #loc612)) +#loc779 = loc(callsite(#loc408 at #loc612)) +#loc780 = loc(callsite(#loc409 at #loc612)) +#loc781 = loc(callsite(#loc410 at #loc612)) +#loc782 = loc("kT_ptrs"(#loc613)) +#loc783 = loc(callsite(#loc415 at #loc612)) +#loc784 = loc(callsite(#loc416 at #loc612)) +#loc785 = loc(callsite(#loc417 at #loc612)) +#loc786 = loc(callsite(#loc418 at #loc612)) +#loc787 = loc(callsite(#loc419 at #loc612)) +#loc788 = loc(callsite(#loc420 at #loc612)) +#loc789 = loc(callsite(#loc421 at #loc612)) +#loc790 = loc(callsite(#loc422 at #loc612)) +#loc791 = loc(callsite(#loc423 at #loc612)) +#loc792 = loc(callsite(#loc424 at #loc612)) +#loc793 = loc(callsite(#loc425 at #loc612)) +#loc794 = loc(callsite(#loc426 at #loc612)) +#loc795 = loc(callsite(#loc427 at #loc612)) +#loc796 = loc(callsite(#loc428 at #loc612)) +#loc797 = loc(callsite(#loc429 at #loc612)) +#loc798 = loc(callsite(#loc430 at #loc612)) +#loc799 = loc(callsite(#loc431 at #loc612)) +#loc800 = loc(callsite(#loc432 at #loc612)) +#loc801 = loc(callsite(#loc433 at #loc612)) +#loc802 = loc(callsite(#loc434 at #loc612)) +#loc803 = loc(callsite(#loc435 at #loc612)) +#loc804 = loc(callsite(#loc436 at #loc612)) +#loc805 = loc(callsite(#loc437 at #loc612)) +#loc806 = loc(callsite(#loc438 at #loc612)) +#loc807 = loc(callsite(#loc439 at #loc612)) +#loc808 = loc(callsite(#loc440 at #loc612)) +#loc809 = loc(callsite(#loc441 at #loc617)) +#loc810 = loc(callsite(#loc443 at #loc617)) +#loc811 = loc(callsite(#loc444 at #loc617)) +#loc812 = loc(callsite(#loc445 at #loc617)) +#loc813 = loc(callsite(#loc446 at #loc617)) +#loc814 = loc(callsite(#loc447 at #loc617)) +#loc815 = loc(callsite(#loc448 at #loc617)) +#loc816 = loc(callsite(#loc449 at #loc617)) +#loc817 = loc(callsite(#loc450 at #loc617)) +#loc818 = loc(callsite(#loc451 at #loc617)) +#loc819 = loc(callsite(#loc452 at #loc617)) +#loc820 = loc(callsite(#loc453 at #loc617)) +#loc821 = loc(callsite(#loc454 at #loc617)) +#loc822 = loc(callsite(#loc455 at #loc617)) +#loc823 = loc(callsite(#loc456 at #loc617)) +#loc824 = loc(callsite(#loc457 at #loc617)) +#loc825 = loc(callsite(#loc458 at #loc617)) +#loc826 = loc(callsite(#loc392 at #loc627)) +#loc827 = loc(callsite(#loc410 at #loc627)) +#loc828 = loc(callsite(#loc415 at #loc627)) +#loc829 = loc(callsite(#loc416 at #loc627)) +#loc830 = loc(callsite(#loc417 at #loc627)) +#loc831 = loc(callsite(#loc419 at #loc627)) +#loc832 = loc(callsite(#loc433 at #loc627)) +#loc833 = loc(callsite(#loc407 at #loc627)) +#loc834 = loc(callsite(#loc434 at #loc627)) +#loc835 = loc(callsite(#loc435 at #loc627)) +#loc836 = loc(callsite(#loc409 at #loc627)) +#loc837 = loc(callsite(#loc436 at #loc627)) +#loc838 = loc(callsite(#loc437 at #loc627)) +#loc839 = loc(callsite(#loc439 at #loc627)) +#loc840 = loc(callsite(#loc440 at #loc627)) +#loc841 = loc(callsite(#loc441 at #loc631)) +#loc842 = loc(callsite(#loc443 at #loc631)) +#loc843 = loc(callsite(#loc444 at #loc631)) +#loc844 = loc(callsite(#loc445 at #loc631)) +#loc845 = loc(callsite(#loc446 at #loc631)) +#loc846 = loc(callsite(#loc447 at #loc631)) +#loc847 = loc(callsite(#loc448 at #loc631)) +#loc848 = loc(callsite(#loc449 at #loc631)) +#loc849 = loc(callsite(#loc450 at #loc631)) +#loc850 = loc(callsite(#loc451 at #loc631)) +#loc851 = loc(callsite(#loc452 at #loc631)) +#loc852 = loc(callsite(#loc453 at #loc631)) +#loc853 = loc(callsite(#loc454 at #loc631)) +#loc854 = loc(callsite(#loc455 at #loc631)) +#loc855 = loc(callsite(#loc456 at #loc631)) +#loc856 = loc(callsite(#loc457 at #loc631)) +#loc857 = loc(callsite(#loc458 at #loc631)) +#loc858 = loc(callsite(#loc12 at #loc651)) +#loc859 = loc(callsite(#loc14 at #loc651)) +#loc860 = loc(callsite(#loc89 at #loc654)) +#loc861 = loc(callsite(#loc92 at #loc656)) +#loc862 = loc(callsite(#loc59 at #loc669)) +#loc863 = loc(callsite(#loc109 at #loc654)) +#loc864 = loc(callsite(#loc60 at #loc669)) +#loc865 = loc("offs_m1"(#loc684)) +#loc866 = loc(callsite(#loc115 at #loc654)) +#loc867 = loc(callsite(#loc251 at #loc669)) +#loc868 = loc(callsite(#loc92 at #loc692)) +#loc869 = loc(callsite(#loc441 at #loc722)) +#loc870 = loc(callsite(#loc443 at #loc722)) +#loc871 = loc(callsite(#loc444 at #loc722)) +#loc872 = loc(callsite(#loc445 at #loc722)) +#loc873 = loc(callsite(#loc446 at #loc722)) +#loc874 = loc(callsite(#loc447 at #loc722)) +#loc875 = loc(callsite(#loc448 at #loc722)) +#loc876 = loc(callsite(#loc449 at #loc722)) +#loc877 = loc(callsite(#loc450 at #loc722)) +#loc878 = loc(callsite(#loc451 at #loc722)) +#loc879 = loc(callsite(#loc452 at #loc722)) +#loc880 = loc(callsite(#loc453 at #loc722)) +#loc881 = loc(callsite(#loc454 at #loc722)) +#loc882 = loc(callsite(#loc455 at #loc722)) +#loc883 = loc(callsite(#loc456 at #loc722)) +#loc884 = loc(callsite(#loc457 at #loc722)) +#loc885 = loc(callsite(#loc458 at #loc722)) +#loc886 = loc(callsite(#loc109 at #loc729)) +#loc887 = loc(callsite(#loc60 at #loc731)) +#loc888 = loc(callsite(#loc89 at #loc729)) +#loc889 = loc(callsite(#loc59 at #loc731)) +#loc890 = loc(callsite(#loc115 at #loc729)) +#loc891 = loc(callsite(#loc251 at #loc731)) +#loc892 = loc(callsite(#loc441 at #loc759)) +#loc893 = loc(callsite(#loc443 at #loc759)) +#loc894 = loc(callsite(#loc444 at #loc759)) +#loc895 = loc(callsite(#loc445 at #loc759)) +#loc896 = loc(callsite(#loc446 at #loc759)) +#loc897 = loc(callsite(#loc447 at #loc759)) +#loc898 = loc(callsite(#loc448 at #loc759)) +#loc899 = loc(callsite(#loc449 at #loc759)) +#loc900 = loc(callsite(#loc450 at #loc759)) +#loc901 = loc(callsite(#loc451 at #loc759)) +#loc902 = loc(callsite(#loc452 at #loc759)) +#loc903 = loc(callsite(#loc453 at #loc759)) +#loc904 = loc(callsite(#loc454 at #loc759)) +#loc905 = loc(callsite(#loc455 at #loc759)) +#loc906 = loc(callsite(#loc456 at #loc759)) +#loc907 = loc(callsite(#loc457 at #loc759)) +#loc908 = loc(callsite(#loc458 at #loc759)) +#loc909 = loc(callsite(#loc89 at #loc764)) +#loc910 = loc(callsite(#loc92 at #loc765)) +#loc911 = loc(callsite(#loc109 at #loc764)) +#loc912 = loc(callsite(#loc109 at #loc781)) +#loc913 = loc("vT_ptrs"(#loc782)) +#loc914 = loc(callsite(#loc115 at #loc764)) +#loc915 = loc(callsite(#loc92 at #loc786)) +#loc916 = loc(callsite(#loc109 at #loc826)) +#loc917 = loc(callsite(#loc109 at #loc827)) +#loc918 = loc(callsite(#loc89 at #loc826)) +#loc919 = loc(callsite(#loc115 at #loc826)) +#loc920 = loc("qT_ptrs"(#loc865)) +#loc921 = loc(callsite(#loc913 at #loc381)) +#loc922 = loc(callsite(#loc913 at #loc466)) +#loc923 = loc("do_ptrs"(#loc920)) +#loc924 = loc(callsite(#loc923 at #loc190)) +#loc925 = loc(callsite(#loc923 at #loc224)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttir new file mode 100644 index 0000000000000000000000000000000000000000..a1aa62385a6b5cd14a9e304b4044d89efb988436 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/PJJES3QEVXF7MPESQRKFQ4D55L4Y7YJPTGXSVMGRCNUVXD3MMXGQ/triton_tem_fused_mul_1.ttir @@ -0,0 +1,1530 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":18:0) +#loc304 = loc("arg_Q"(#loc)) +#loc305 = loc("arg_K"(#loc)) +#loc306 = loc("arg_V"(#loc)) +#loc307 = loc("arg_LSE"(#loc)) +#loc308 = loc("arg_DELTA"(#loc)) +#loc309 = loc("arg_DO"(#loc)) +#loc310 = loc("arg_DQ"(#loc)) +#loc311 = loc("arg_DV"(#loc)) +#loc312 = loc("arg_KV_NUM_BLKS"(#loc)) +#loc313 = loc("arg_KV_IDX"(#loc)) +#loc314 = loc("arg_Q_NUM_BLKS"(#loc)) +#loc315 = loc("arg_Q_IDX"(#loc)) +#loc316 = loc("arg_FULL_KV_NUM_BLKS"(#loc)) +#loc317 = loc("arg_FULL_KV_IDX"(#loc)) +#loc318 = loc("arg_FULL_Q_NUM_BLKS"(#loc)) +#loc319 = loc("arg_FULL_Q_IDX"(#loc)) +#loc320 = loc("out_ptr0"(#loc)) +#loc321 = loc("ks0"(#loc)) +#loc322 = loc("ks1"(#loc)) +module { + tt.func public @triton_tem_fused_mul_1(%arg_Q: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q"(#loc)), %arg_K: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_K"(#loc)), %arg_V: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_V"(#loc)), %arg_LSE: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_LSE"(#loc)), %arg_DELTA: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DELTA"(#loc)), %arg_DO: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DO"(#loc)), %arg_DQ: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DQ"(#loc)), %arg_DV: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_DV"(#loc)), %arg_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_NUM_BLKS"(#loc)), %arg_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_KV_IDX"(#loc)), %arg_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_NUM_BLKS"(#loc)), %arg_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_Q_IDX"(#loc)), %arg_FULL_KV_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_NUM_BLKS"(#loc)), %arg_FULL_KV_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_KV_IDX"(#loc)), %arg_FULL_Q_NUM_BLKS: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_NUM_BLKS"(#loc)), %arg_FULL_Q_IDX: !tt.ptr {tt.divisibility = 16 : i32} loc("arg_FULL_Q_IDX"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %ks0: i32 loc("ks0"(#loc)), %ks1: i32 loc("ks1"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<128> : tensor<64x1xi32> loc(#loc1) + %cst_0 = arith.constant dense<4096> : tensor<1x64xi32> loc(#loc1) + %cst_1 = arith.constant dense<1024> : tensor<1x64xi32> loc(#loc1) + %cst_2 = arith.constant dense<128> : tensor<128x1xi32> loc(#loc1) + %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x128xbf16> loc(#loc1) + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64xf32> loc(#loc1) + %cst_5 = arith.constant dense<0xFF800000> : tensor<64xf32> loc(#loc1) + %c2_i32 = arith.constant 2 : i32 loc(#loc1) + %c64_i32 = arith.constant 64 : i32 loc(#loc1) + %c63_i32 = arith.constant 63 : i32 loc(#loc1) + %cst_6 = arith.constant dense<16> : tensor<1x64xi32> loc(#loc1) + %cst_7 = arith.constant dense<16> : tensor<128x1xi32> loc(#loc1) + %cst_8 = arith.constant dense<1.44269502> : tensor<128x64xf32> loc(#loc1) + %cst_9 = arith.constant dense<1> : tensor<1x64xi32> loc(#loc1) + %cst_10 = arith.constant dense<1> : tensor<128x1xi32> loc(#loc1) + %cst_11 = arith.constant dense<0> : tensor<128x1xi32> loc(#loc1) + %cst_12 = arith.constant dense<0> : tensor<1x64xi32> loc(#loc1) + %cst_13 = arith.constant dense<0xFF800000> : tensor<128x64xf32> loc(#loc1) + %cst_14 = arith.constant dense<0.0883883461> : tensor<128x64xf32> loc(#loc1) + %cst_15 = arith.constant dense<0.000000e+00> : tensor<128x64xf32> loc(#loc1) + %cst_16 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> loc(#loc1) + %cst_17 = arith.constant dense<0.000000e+00> : tensor<128x128xbf16> loc(#loc1) + %cst_18 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> loc(#loc1) + %c127_i32 = arith.constant 127 : i32 loc(#loc1) + %c0_i32 = arith.constant 0 : i32 loc(#loc1) + %cst_19 = arith.constant dense<1024> : tensor<128x1xi32> loc(#loc1) + %cst_20 = arith.constant dense<128> : tensor<1x128xi32> loc(#loc1) + %cst_21 = arith.constant dense<0.0883883461> : tensor<128x128xf32> loc(#loc1) + %cst_22 = arith.constant dense<4096> : tensor<128x1xi32> loc(#loc1) + %cst_23 = arith.constant dense<0.000000e+00> : tensor<128xf32> loc(#loc1) + %cst_24 = arith.constant dense<0xFF800000> : tensor<128xf32> loc(#loc1) + %c4_i32 = arith.constant 4 : i32 loc(#loc1) + %HQ = arith.constant 32 : i32 loc(#loc323) + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %c1_i32 = arith.constant 1 : i32 loc(#loc1) + %c128_i32 = arith.constant 128 : i32 loc(#loc1) + %c4096_i32 = arith.constant 4096 : i32 loc(#loc1) + %0 = arith.muli %ks0, %c4096_i32 : i32 loc(#loc3) + %1 = arith.cmpi sle, %ks0, %c1_i32 : i32 loc(#loc4) + %2 = arith.extui %1 : i1 to i32 loc(#loc5) + %3 = arith.cmpi sgt, %ks0, %c1_i32 : i32 loc(#loc6) + %4 = arith.extui %3 : i1 to i32 loc(#loc7) + %5 = arith.muli %ks0, %4 : i32 loc(#loc7) + %6 = arith.addi %2, %5 : i32 loc(#loc8) + %7 = arith.muli %6, %c4096_i32 : i32 loc(#loc9) + %8 = arith.muli %6, %c128_i32 : i32 loc(#loc10) + %9 = arith.muli %ks1, %c1024_i32 : i32 loc(#loc11) + %pid = tt.get_program_id x : i32 loc(#loc324) + %NUM_KV_BLOCKS = arith.addi %ks1, %c127_i32 : i32 loc(#loc592) + %NUM_KV_BLOCKS_25 = arith.divsi %NUM_KV_BLOCKS, %c128_i32 : i32 loc(#loc593) + %NUM_Q_BLOCKS = arith.addi %ks0, %c127_i32 : i32 loc(#loc594) + %NUM_Q_BLOCKS_26 = arith.divsi %NUM_Q_BLOCKS, %c128_i32 : i32 loc(#loc595) + %off_zq = tt.get_program_id y : i32 loc(#loc327) + %off_hkv = tt.get_program_id z : i32 loc(#loc328) + %k_adj = arith.muli %off_hkv, %c128_i32 : i32 loc(#loc329) + %k_adj_27 = arith.extsi %k_adj : i32 to i64 loc(#loc330) + %dv_adj = arith.muli %9, %off_zq : i32 loc(#loc331) + %dv_adj_28 = arith.addi %k_adj, %dv_adj : i32 loc(#loc332) + %dv_adj_29 = arith.extsi %dv_adj_28 : i32 to i64 loc(#loc333) + %K = tt.addptr %arg_K, %k_adj_27 : !tt.ptr, i64 loc(#loc334) + %V = tt.addptr %arg_V, %k_adj_27 : !tt.ptr, i64 loc(#loc335) + %DV = tt.addptr %arg_DV, %dv_adj_29 : !tt.ptr, i64 loc(#loc336) + %offs_k = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> loc(#loc337) + %10 = arith.cmpi sge, %pid, %NUM_KV_BLOCKS_25 : i32 loc(#loc28) + scf.if %10 { + %off_pid = arith.subi %pid, %NUM_KV_BLOCKS_25 : i32 loc(#loc338) + %off_hq2 = arith.divsi %off_pid, %NUM_Q_BLOCKS_26 : i32 loc(#loc339) + %off_hq2_30 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc340) + %off_hq2_31 = arith.addi %off_hq2, %off_hq2_30 : i32 loc(#loc341) + %start_m2_block = arith.remsi %off_pid, %NUM_Q_BLOCKS_26 : i32 loc(#loc342) + %q_adj2 = arith.muli %off_hq2_31, %c128_i32 : i32 loc(#loc343) + %q_adj2_32 = arith.muli %0, %off_zq : i32 loc(#loc344) + %q_adj2_33 = arith.addi %q_adj2, %q_adj2_32 : i32 loc(#loc345) + %q_adj2_34 = arith.extsi %q_adj2_33 : i32 to i64 loc(#loc346) + %do_adj2 = arith.muli %8, %off_hq2_31 : i32 loc(#loc347) + %do_adj2_35 = arith.muli %7, %off_zq : i32 loc(#loc348) + %do_adj2_36 = arith.addi %do_adj2, %do_adj2_35 : i32 loc(#loc349) + %do_adj2_37 = arith.extsi %do_adj2_36 : i32 to i64 loc(#loc350) + %off_chz2 = arith.muli %off_zq, %HQ : i32 loc(#loc351) + %off_chz2_38 = arith.addi %off_chz2, %off_hq2_31 : i32 loc(#loc352) + %off_chz2_39 = arith.muli %off_chz2_38, %ks0 : i32 loc(#loc353) + %off_chz2_40 = arith.extsi %off_chz2_39 : i32 to i64 loc(#loc354) + %Q2 = tt.addptr %arg_Q, %q_adj2_34 : !tt.ptr, i64 loc(#loc355) + %DO2 = tt.addptr %arg_DO, %do_adj2_37 : !tt.ptr, i64 loc(#loc356) + %DQ2 = tt.addptr %arg_DQ, %q_adj2_34 : !tt.ptr, i64 loc(#loc357) + %LSE2 = tt.addptr %arg_LSE, %off_chz2_40 : !tt.ptr, i64 loc(#loc358) + %DELTA2 = tt.addptr %arg_DELTA, %off_chz2_40 : !tt.ptr, i64 loc(#loc359) + %start_m2 = arith.muli %start_m2_block, %c128_i32 : i32 loc(#loc360) + %offs_m2 = tt.splat %start_m2 : i32 -> tensor<128xi32> loc(#loc361) + %offs_m2_41 = arith.addi %offs_m2, %offs_k : tensor<128xi32> loc(#loc361) + %ptr = tt.expand_dims %offs_m2_41 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc596) + %ptr_42 = arith.muli %ptr, %cst_22 : tensor<128x1xi32> loc(#loc597) + %ptr_43 = tt.splat %Q2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc598) + %ptr_44 = tt.addptr %ptr_43, %ptr_42 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc598) + %ptr_45 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc599) + %ptr_46 = tt.broadcast %ptr_44 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc600) + %ptr_47 = tt.broadcast %ptr_45 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc600) + %ptr_48 = tt.addptr %ptr_46, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc600) + %q = tt.splat %ks0 : i32 -> tensor<128x1xi32> loc(#loc601) + %q_49 = arith.cmpi slt, %ptr, %q : tensor<128x1xi32> loc(#loc601) + %q_50 = tt.broadcast %q_49 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc602) + %q_51 = tt.load %ptr_48, %q_50, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc602) + %ptr_52 = arith.muli %ptr, %cst_2 : tensor<128x1xi32> loc(#loc603) + %ptr_53 = tt.splat %DO2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc604) + %ptr_54 = tt.addptr %ptr_53, %ptr_52 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc604) + %ptr_55 = tt.broadcast %ptr_54 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc605) + %ptr_56 = tt.addptr %ptr_55, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc605) + %do = tt.load %ptr_56, %q_50, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc606) + %Di = tt.splat %ks0 : i32 -> tensor<128xi32> loc(#loc369) + %Di_57 = arith.cmpi slt, %offs_m2_41, %Di : tensor<128xi32> loc(#loc369) + %Di_58 = tt.splat %DELTA2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc370) + %Di_59 = tt.addptr %Di_58, %offs_m2_41 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc370) + %Di_60 = tt.load %Di_59, %Di_57 : tensor<128x!tt.ptr> loc(#loc371) + %lse = tt.splat %LSE2 : !tt.ptr -> tensor<128x!tt.ptr> loc(#loc372) + %lse_61 = tt.addptr %lse, %offs_m2_41 : tensor<128x!tt.ptr>, tensor<128xi32> loc(#loc372) + %lse_62 = tt.load %lse_61, %Di_57 : tensor<128x!tt.ptr> loc(#loc373) + %lse_63 = arith.cmpf oeq, %lse_62, %cst_24 : tensor<128xf32> loc(#loc374) + %lse_64 = arith.select %lse_63, %cst_23, %lse_62 : tensor<128xi1>, tensor<128xf32> loc(#loc375) + %lse_65 = tt.expand_dims %lse_64 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc376) + %kv_indices = tt.addptr %arg_KV_IDX, %start_m2_block : !tt.ptr, i32 loc(#loc377) + %kv_start = tt.load %kv_indices : !tt.ptr loc(#loc378) + %kv_start_66 = arith.muli %kv_start, %c128_i32 : i32 loc(#loc379) + %sparse_kv_num_blocks = tt.addptr %arg_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc380) + %sparse_kv_num_blocks_67 = tt.load %sparse_kv_num_blocks : !tt.ptr loc(#loc381) + %offs_n2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc382) + %offs_n2_68 = tt.splat %kv_start_66 : i32 -> tensor<64xi32> loc(#loc383) + %offs_n2_69 = arith.addi %offs_n2_68, %offs_n2 : tensor<64xi32> loc(#loc383) + %kT_ptrs = tt.expand_dims %offs_n2_69 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc607) + %kT_ptrs_70 = arith.muli %kT_ptrs, %cst_1 : tensor<1x64xi32> loc(#loc608) + %kT_ptrs_71 = tt.splat %K : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc609) + %kT_ptrs_72 = tt.addptr %kT_ptrs_71, %kT_ptrs_70 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc609) + %kT_ptrs_73 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc610) + %kT_ptrs_74 = tt.broadcast %kT_ptrs_72 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc611) + %kT_ptrs_75 = tt.broadcast %kT_ptrs_73 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc611) + %kT_ptrs_76 = tt.addptr %kT_ptrs_74, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc611) + %vT_ptrs = tt.splat %V : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc612) + %vT_ptrs_77 = tt.addptr %vT_ptrs, %kT_ptrs_70 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc612) + %vT_ptrs_78 = tt.broadcast %vT_ptrs_77 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc613) + %vT_ptrs_79 = tt.addptr %vT_ptrs_78, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc613) + %hi = arith.muli %sparse_kv_num_blocks_67, %c2_i32 : i32 loc(#loc614) + %hi_80 = arith.addi %ks1, %c63_i32 : i32 loc(#loc770) + %hi_81 = arith.divsi %hi_80, %c64_i32 : i32 loc(#loc771) + %hi_82 = arith.maxsi %hi_81, %c1_i32 : i32 loc(#loc616) + %hi_83 = arith.minsi %hi, %hi_82 : i32 loc(#loc617) + %vT_ptrs_84:4 = scf.for %start_n = %c0_i32 to %hi_83 step %c1_i32 iter_args(%dq_106 = %cst_18, %offs_n2_107 = %offs_n2_69, %kT_ptrs_108 = %kT_ptrs_76, %vT_ptrs_109 = %vT_ptrs_79) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.expand_dims %offs_n2_107 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc919) + %kT_110 = tt.splat %ks1 : i32 -> tensor<1x64xi32> loc(#loc920) + %kT_111 = arith.cmpi slt, %kT, %kT_110 : tensor<1x64xi32> loc(#loc920) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc921) + %kT_113 = tt.load %kT_ptrs_108, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc921) + %qk = tt.dot %q_51, %kT_113, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc774) + %qk_114 = arith.mulf %qk, %cst_14 : tensor<128x64xf32> loc(#loc775) + %n = arith.remsi %kT, %kT_110 : tensor<1x64xi32> loc(#loc922) + %m = arith.remsi %ptr, %q : tensor<128x1xi32> loc(#loc923) + %post_mod_scores = arith.select %kT_112, %qk_114, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc778) + %tmp3 = arith.cmpi slt, %m, %cst_11 : tensor<128x1xi32> loc(#loc779) + %tmp5 = tt.broadcast %n : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc780) + %tmp5_115 = tt.broadcast %m : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc780) + %tmp5_116 = arith.cmpi sle, %tmp5, %tmp5_115 : tensor<128x64xi32> loc(#loc780) + %tmp6 = tt.broadcast %tmp3 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc781) + %tmp6_117 = arith.andi %tmp6, %tmp5_116 : tensor<128x64xi1> loc(#loc781) + %tmp7 = arith.cmpi sge, %m, %cst_11 : tensor<128x1xi32> loc(#loc782) + %tmp8 = arith.cmpi slt, %n, %cst_12 : tensor<1x64xi32> loc(#loc783) + %tmp9 = tt.broadcast %tmp7 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc784) + %tmp9_118 = tt.broadcast %tmp8 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc784) + %tmp9_119 = arith.andi %tmp9, %tmp9_118 : tensor<128x64xi1> loc(#loc784) + %tmp10 = arith.extui %tmp8 : tensor<1x64xi1> to tensor<1x64xi32> loc(#loc785) + %tmp10_120 = arith.cmpi eq, %tmp10, %cst_12 : tensor<1x64xi32> loc(#loc785) + %tmp11 = tt.broadcast %tmp10_120 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc786) + %tmp11_121 = arith.andi %tmp9, %tmp11 : tensor<128x64xi1> loc(#loc786) + %tmp14 = arith.remsi %m, %cst_7 : tensor<128x1xi32> loc(#loc787) + %tmp14_122 = arith.cmpi ne, %tmp14, %cst_11 : tensor<128x1xi32> loc(#loc788) + %tmp14_123 = arith.divsi %m, %cst_7 : tensor<128x1xi32> loc(#loc789) + %tmp14_124 = arith.subi %tmp14_123, %cst_10 : tensor<128x1xi32> loc(#loc790) + %tmp14_125 = arith.select %tmp14_122, %tmp14_124, %tmp14_123 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc791) + %tmp14_126 = arith.select %tmp3, %tmp14_125, %tmp14_123 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc792) + %tmp16 = arith.remsi %n, %cst_6 : tensor<1x64xi32> loc(#loc793) + %tmp16_127 = arith.cmpi ne, %tmp16, %cst_12 : tensor<1x64xi32> loc(#loc794) + %tmp16_128 = arith.divsi %n, %cst_6 : tensor<1x64xi32> loc(#loc795) + %tmp16_129 = arith.subi %tmp16_128, %cst_9 : tensor<1x64xi32> loc(#loc796) + %tmp16_130 = arith.select %tmp16_127, %tmp16_129, %tmp16_128 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc797) + %tmp16_131 = arith.select %tmp8, %tmp16_130, %tmp16_128 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc798) + %tmp17 = tt.broadcast %tmp14_126 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc799) + %tmp17_132 = tt.broadcast %tmp16_131 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc799) + %tmp17_133 = arith.cmpi eq, %tmp17, %tmp17_132 : tensor<128x64xi32> loc(#loc799) + %tmp18 = arith.andi %tmp11_121, %tmp17_133 : tensor<128x64xi1> loc(#loc800) + %tmp19 = arith.ori %tmp9_119, %tmp18 : tensor<128x64xi1> loc(#loc801) + %tmp20 = arith.ori %tmp6_117, %tmp19 : tensor<128x64xi1> loc(#loc802) + %post_mod_scores_134 = arith.select %tmp20, %post_mod_scores, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc803) + %post_mod_scores_135 = arith.mulf %post_mod_scores_134, %cst_8 : tensor<128x64xf32> loc(#loc804) + %p = tt.broadcast %lse_65 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc805) + %p_136 = arith.subf %post_mod_scores_135, %p : tensor<128x64xf32> loc(#loc805) + %p_137 = math.exp2 %p_136 : tensor<128x64xf32> loc(#loc806) + %vT = tt.load %vT_ptrs_109, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc924) + %dp = tt.dot %do, %vT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc808) + %ds = tt.expand_dims %Di_60 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc809) + %ds_138 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc810) + %ds_139 = arith.subf %dp, %ds_138 : tensor<128x64xf32> loc(#loc810) + %ds_140 = arith.mulf %p_137, %ds_139 : tensor<128x64xf32> loc(#loc811) + %grad_scores = arith.select %kT_112, %ds_140, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc812) + %ds_141 = arith.select %tmp20, %grad_scores, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc813) + %ds_142 = arith.truncf %ds_141 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc814) + %dq_143 = tt.trans %kT_113 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc815) + %dq_144 = tt.dot %ds_142, %dq_143, %dq_106, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc816) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc817) + %cur_block = tt.addptr %kv_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc818) + %cur_block_145 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc819) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc820) + %next_block_146 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_67 : i32 loc(#loc821) + %next_block_147 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc822) + %next_block_148 = tt.load %next_block_147, %next_block_146 evictionPolicy = evict_last : !tt.ptr loc(#loc823) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc824) + %needs_jump_149 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc825) + %needs_jump_150 = arith.cmpi eq, %needs_jump_149, %c0_i32 : i32 loc(#loc826) + %jump_to_block = arith.subi %next_block_148, %cur_block_145 : i32 loc(#loc827) + %jump_to_block_151 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc828) + %jump_to_block_152 = arith.subi %jump_to_block_151, %c64_i32 : i32 loc(#loc829) + %offset = arith.extui %needs_jump_150 : i1 to i32 loc(#loc830) + %offset_153 = arith.muli %jump_to_block_152, %offset : i32 loc(#loc830) + %offset_154 = arith.subi %c1_i32, %offset : i32 loc(#loc831) + %offset_155 = arith.muli %offset_154, %c64_i32 : i32 loc(#loc832) + %offset_156 = arith.addi %offset_153, %offset_155 : i32 loc(#loc833) + %kT_ptrs_157 = arith.muli %offset_156, %c1024_i32 : i32 loc(#loc621) + %kT_ptrs_158 = tt.splat %kT_ptrs_157 : i32 -> tensor<128x64xi32> loc(#loc622) + %kT_ptrs_159 = tt.addptr %kT_ptrs_108, %kT_ptrs_158 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc622) + %vT_ptrs_160 = tt.addptr %vT_ptrs_109, %kT_ptrs_158 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc623) + %offs_n2_161 = tt.splat %offset_156 : i32 -> tensor<64xi32> loc(#loc624) + %offs_n2_162 = arith.addi %offs_n2_107, %offs_n2_161 : tensor<64xi32> loc(#loc624) + scf.yield %dq_144, %offs_n2_162, %kT_ptrs_159, %vT_ptrs_160 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc625) + } loc(#loc930) + %kv_indices_85 = tt.addptr %arg_FULL_KV_IDX, %start_m2_block : !tt.ptr, i32 loc(#loc464) + %kv_start_86 = tt.load %kv_indices_85 : !tt.ptr loc(#loc465) + %kv_start_87 = arith.muli %kv_start_86, %c128_i32 : i32 loc(#loc466) + %sparse_kv_num_blocks_88 = tt.addptr %arg_FULL_KV_NUM_BLKS, %start_m2_block : !tt.ptr, i32 loc(#loc467) + %sparse_kv_num_blocks_89 = tt.load %sparse_kv_num_blocks_88 : !tt.ptr loc(#loc468) + %offs_n2_90 = tt.splat %kv_start_87 : i32 -> tensor<64xi32> loc(#loc469) + %offs_n2_91 = arith.addi %offs_n2_90, %offs_n2 : tensor<64xi32> loc(#loc469) + %kT_ptrs_92 = tt.expand_dims %offs_n2_91 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc626) + %kT_ptrs_93 = arith.muli %kT_ptrs_92, %cst_1 : tensor<1x64xi32> loc(#loc627) + %kT_ptrs_94 = tt.addptr %kT_ptrs_71, %kT_ptrs_93 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc628) + %kT_ptrs_95 = tt.broadcast %kT_ptrs_94 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc629) + %kT_ptrs_96 = tt.addptr %kT_ptrs_95, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc629) + %vT_ptrs_97 = tt.addptr %vT_ptrs, %kT_ptrs_93 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc630) + %vT_ptrs_98 = tt.broadcast %vT_ptrs_97 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc631) + %vT_ptrs_99 = tt.addptr %vT_ptrs_98, %kT_ptrs_75 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc631) + %hi_100 = arith.muli %sparse_kv_num_blocks_89, %c2_i32 : i32 loc(#loc632) + %hi_101 = arith.minsi %hi_100, %hi_82 : i32 loc(#loc633) + %vT_ptrs_102:4 = scf.for %start_n = %c0_i32 to %hi_101 step %c1_i32 iter_args(%dq_106 = %vT_ptrs_84#0, %offs_n2_107 = %offs_n2_91, %kT_ptrs_108 = %kT_ptrs_96, %vT_ptrs_109 = %vT_ptrs_99) -> (tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr>) : i32 { + %kT = tt.expand_dims %offs_n2_107 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc925) + %kT_110 = tt.splat %ks1 : i32 -> tensor<1x64xi32> loc(#loc926) + %kT_111 = arith.cmpi slt, %kT, %kT_110 : tensor<1x64xi32> loc(#loc926) + %kT_112 = tt.broadcast %kT_111 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc927) + %kT_113 = tt.load %kT_ptrs_108, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc927) + %qk = tt.dot %q_51, %kT_113, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc835) + %qk_114 = arith.mulf %qk, %cst_14 : tensor<128x64xf32> loc(#loc836) + %post_mod_scores = arith.select %kT_112, %qk_114, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc837) + %post_mod_scores_115 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32> loc(#loc838) + %p = tt.broadcast %lse_65 : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc839) + %p_116 = arith.subf %post_mod_scores_115, %p : tensor<128x64xf32> loc(#loc839) + %p_117 = math.exp2 %p_116 : tensor<128x64xf32> loc(#loc840) + %vT = tt.load %vT_ptrs_109, %kT_112, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc928) + %dp = tt.dot %do, %vT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc842) + %ds = tt.expand_dims %Di_60 {axis = 1 : i32} : tensor<128xf32> -> tensor<128x1xf32> loc(#loc843) + %ds_118 = tt.broadcast %ds : tensor<128x1xf32> -> tensor<128x64xf32> loc(#loc844) + %ds_119 = arith.subf %dp, %ds_118 : tensor<128x64xf32> loc(#loc844) + %ds_120 = arith.mulf %p_117, %ds_119 : tensor<128x64xf32> loc(#loc845) + %grad_scores = arith.select %kT_112, %ds_120, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc846) + %ds_121 = arith.truncf %grad_scores : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc847) + %dq_122 = tt.trans %kT_113 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc848) + %dq_123 = tt.dot %ds_121, %dq_122, %dq_106, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc849) + %cur_block_idx = arith.divsi %start_n, %c2_i32 : i32 loc(#loc850) + %cur_block = tt.addptr %kv_indices_85, %cur_block_idx : !tt.ptr, i32 loc(#loc851) + %cur_block_124 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc852) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc853) + %next_block_125 = arith.cmpi slt, %next_block, %sparse_kv_num_blocks_89 : i32 loc(#loc854) + %next_block_126 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc855) + %next_block_127 = tt.load %next_block_126, %next_block_125 evictionPolicy = evict_last : !tt.ptr loc(#loc856) + %needs_jump = arith.addi %start_n, %c1_i32 : i32 loc(#loc857) + %needs_jump_128 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc858) + %needs_jump_129 = arith.cmpi eq, %needs_jump_128, %c0_i32 : i32 loc(#loc859) + %jump_to_block = arith.subi %next_block_127, %cur_block_124 : i32 loc(#loc860) + %jump_to_block_130 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc861) + %jump_to_block_131 = arith.subi %jump_to_block_130, %c64_i32 : i32 loc(#loc862) + %offset = arith.extui %needs_jump_129 : i1 to i32 loc(#loc863) + %offset_132 = arith.muli %jump_to_block_131, %offset : i32 loc(#loc863) + %offset_133 = arith.subi %c1_i32, %offset : i32 loc(#loc864) + %offset_134 = arith.muli %offset_133, %c64_i32 : i32 loc(#loc865) + %offset_135 = arith.addi %offset_132, %offset_134 : i32 loc(#loc866) + %kT_ptrs_136 = arith.muli %offset_135, %c1024_i32 : i32 loc(#loc636) + %kT_ptrs_137 = tt.splat %kT_ptrs_136 : i32 -> tensor<128x64xi32> loc(#loc637) + %kT_ptrs_138 = tt.addptr %kT_ptrs_108, %kT_ptrs_137 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc637) + %vT_ptrs_139 = tt.addptr %vT_ptrs_109, %kT_ptrs_137 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc638) + %offs_n2_140 = tt.splat %offset_135 : i32 -> tensor<64xi32> loc(#loc639) + %offs_n2_141 = arith.addi %offs_n2_107, %offs_n2_140 : tensor<64xi32> loc(#loc639) + scf.yield %dq_123, %offs_n2_141, %kT_ptrs_138, %vT_ptrs_139 : tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<128x64x!tt.ptr> loc(#loc640) + } loc(#loc931) + %dq_ptrs = tt.splat %DQ2 : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc471) + %dq_ptrs_103 = tt.addptr %dq_ptrs, %ptr_42 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc471) + %dq_ptrs_104 = tt.broadcast %dq_ptrs_103 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc472) + %dq_ptrs_105 = tt.addptr %dq_ptrs_104, %ptr_47 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc472) + %dq = arith.mulf %vT_ptrs_102#0, %cst_21 : tensor<128x128xf32> loc(#loc473) + %11 = arith.cmpi slt, %ptr_45, %cst_20 : tensor<1x128xi32> loc(#loc173) + %12 = tt.broadcast %11 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc174) + %13 = arith.andi %q_50, %12 : tensor<128x128xi1> loc(#loc174) + %14 = arith.truncf %dq : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc175) + tt.store %dq_ptrs_105, %14, %13 : tensor<128x128x!tt.ptr> loc(#loc175) + } else { + %start_n1 = arith.muli %pid, %c128_i32 : i32 loc(#loc474) + %offs_n1 = tt.splat %start_n1 : i32 -> tensor<128xi32> loc(#loc475) + %offs_n1_30 = arith.addi %offs_n1, %offs_k : tensor<128xi32> loc(#loc475) + %ptr = tt.expand_dims %offs_n1_30 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc641) + %ptr_31 = arith.muli %ptr, %cst_19 : tensor<128x1xi32> loc(#loc642) + %ptr_32 = tt.splat %K : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc643) + %ptr_33 = tt.addptr %ptr_32, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc643) + %ptr_34 = tt.expand_dims %offs_k {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> loc(#loc644) + %ptr_35 = tt.broadcast %ptr_33 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc645) + %ptr_36 = tt.broadcast %ptr_34 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc645) + %ptr_37 = tt.addptr %ptr_35, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc645) + %k = tt.splat %ks1 : i32 -> tensor<128x1xi32> loc(#loc646) + %k_38 = arith.cmpi slt, %ptr, %k : tensor<128x1xi32> loc(#loc646) + %k_39 = tt.broadcast %k_38 : tensor<128x1xi1> -> tensor<128x128xi1> loc(#loc647) + %k_40 = tt.load %ptr_37, %k_39, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc647) + %ptr_41 = tt.splat %V : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc648) + %ptr_42 = tt.addptr %ptr_41, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc648) + %ptr_43 = tt.broadcast %ptr_42 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc649) + %ptr_44 = tt.addptr %ptr_43, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc649) + %v = tt.load %ptr_44, %k_39, %cst_17 : tensor<128x128x!tt.ptr> loc(#loc650) + %dk:2 = scf.for %off_g = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%dv = %cst_18, %dk_49 = %cst_18) -> (tensor<128x128xf32>, tensor<128x128xf32>) : i32 { + %off_hq1 = arith.muli %off_hkv, %c4_i32 : i32 loc(#loc479) + %off_hq1_50 = arith.addi %off_hq1, %off_g : i32 loc(#loc480) + %q_adj1 = arith.muli %off_hq1_50, %c128_i32 : i32 loc(#loc481) + %q_adj1_51 = arith.muli %0, %off_zq : i32 loc(#loc482) + %q_adj1_52 = arith.addi %q_adj1, %q_adj1_51 : i32 loc(#loc483) + %q_adj1_53 = arith.extsi %q_adj1_52 : i32 to i64 loc(#loc484) + %do_adj1 = arith.muli %8, %off_hq1_50 : i32 loc(#loc485) + %do_adj1_54 = arith.muli %7, %off_zq : i32 loc(#loc486) + %do_adj1_55 = arith.addi %do_adj1, %do_adj1_54 : i32 loc(#loc487) + %do_adj1_56 = arith.extsi %do_adj1_55 : i32 to i64 loc(#loc488) + %off_chz1 = arith.muli %off_zq, %HQ : i32 loc(#loc489) + %off_chz1_57 = arith.addi %off_chz1, %off_hq1_50 : i32 loc(#loc490) + %off_chz1_58 = arith.muli %off_chz1_57, %ks0 : i32 loc(#loc491) + %off_chz1_59 = arith.extsi %off_chz1_58 : i32 to i64 loc(#loc492) + %Q1 = tt.addptr %arg_Q, %q_adj1_53 : !tt.ptr, i64 loc(#loc493) + %DO1 = tt.addptr %arg_DO, %do_adj1_56 : !tt.ptr, i64 loc(#loc494) + %LSE1 = tt.addptr %arg_LSE, %off_chz1_59 : !tt.ptr, i64 loc(#loc495) + %DELTA1 = tt.addptr %arg_DELTA, %off_chz1_59 : !tt.ptr, i64 loc(#loc496) + %q_indices = tt.addptr %arg_Q_IDX, %pid : !tt.ptr, i32 loc(#loc497) + %q_start = tt.load %q_indices : !tt.ptr loc(#loc498) + %q_start_60 = arith.muli %q_start, %c128_i32 : i32 loc(#loc499) + %sparse_q_num_blocks = tt.addptr %arg_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc500) + %sparse_q_num_blocks_61 = tt.load %sparse_q_num_blocks : !tt.ptr loc(#loc501) + %offs_m1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> loc(#loc502) + %offs_m1_62 = tt.splat %q_start_60 : i32 -> tensor<64xi32> loc(#loc503) + %offs_m1_63 = arith.addi %offs_m1_62, %offs_m1 : tensor<64xi32> loc(#loc503) + %qT_ptrs = tt.expand_dims %offs_m1_63 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc652) + %qT_ptrs_64 = arith.muli %qT_ptrs, %cst_0 : tensor<1x64xi32> loc(#loc653) + %qT_ptrs_65 = tt.splat %Q1 : !tt.ptr -> tensor<1x64x!tt.ptr> loc(#loc654) + %qT_ptrs_66 = tt.addptr %qT_ptrs_65, %qT_ptrs_64 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc654) + %qT_ptrs_67 = tt.expand_dims %offs_k {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> loc(#loc655) + %qT_ptrs_68 = tt.broadcast %qT_ptrs_66 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc656) + %qT_ptrs_69 = tt.broadcast %qT_ptrs_67 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc656) + %qT_ptrs_70 = tt.addptr %qT_ptrs_68, %qT_ptrs_69 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc656) + %do_ptrs = tt.expand_dims %offs_m1_63 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc657) + %do_ptrs_71 = arith.muli %do_ptrs, %cst : tensor<64x1xi32> loc(#loc658) + %do_ptrs_72 = tt.splat %DO1 : !tt.ptr -> tensor<64x1x!tt.ptr> loc(#loc659) + %do_ptrs_73 = tt.addptr %do_ptrs_72, %do_ptrs_71 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc659) + %do_ptrs_74 = tt.broadcast %do_ptrs_73 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc660) + %do_ptrs_75 = tt.broadcast %ptr_34 : tensor<1x128xi32> -> tensor<64x128xi32> loc(#loc660) + %do_ptrs_76 = tt.addptr %do_ptrs_74, %do_ptrs_75 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc660) + %hi = arith.muli %sparse_q_num_blocks_61, %c2_i32 : i32 loc(#loc661) + %hi_77 = arith.addi %ks0, %c63_i32 : i32 loc(#loc867) + %hi_78 = arith.divsi %hi_77, %c64_i32 : i32 loc(#loc868) + %hi_79 = arith.maxsi %hi_78, %c1_i32 : i32 loc(#loc663) + %hi_80 = arith.minsi %hi, %hi_79 : i32 loc(#loc664) + %do_ptrs_81:5 = scf.for %start_m = %c0_i32 to %hi_80 step %c1_i32 iter_args(%dk_102 = %dk_49, %dv_103 = %dv, %offs_m1_104 = %offs_m1_63, %qT_ptrs_105 = %qT_ptrs_70, %do_ptrs_106 = %do_ptrs_76) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.expand_dims %offs_m1_104 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc870) + %qT_107 = tt.splat %ks0 : i32 -> tensor<1x64xi32> loc(#loc871) + %qT_108 = arith.cmpi slt, %qT, %qT_107 : tensor<1x64xi32> loc(#loc871) + %qT_109 = tt.broadcast %qT_108 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc872) + %qT_110 = tt.load %qT_ptrs_105, %qT_109, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc872) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32> loc(#loc667) + %lse_111 = arith.cmpi slt, %offs_m1_104, %lse : tensor<64xi32> loc(#loc667) + %lse_112 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc668) + %lse_113 = tt.addptr %lse_112, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc668) + %lse_114 = tt.load %lse_113, %lse_111 : tensor<64x!tt.ptr> loc(#loc669) + %lse_115 = arith.cmpf oeq, %lse_114, %cst_5 : tensor<64xf32> loc(#loc670) + %lse_116 = arith.select %lse_115, %cst_4, %lse_114 : tensor<64xi1>, tensor<64xf32> loc(#loc671) + %qkT = tt.dot %k_40, %qT_110, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc672) + %qkT_117 = arith.mulf %qkT, %cst_14 : tensor<128x64xf32> loc(#loc673) + %m = arith.remsi %qT, %qT_107 : tensor<1x64xi32> loc(#loc873) + %n = arith.remsi %ptr, %k : tensor<128x1xi32> loc(#loc874) + %post_mod_scores = arith.select %qT_109, %qkT_117, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc676) + %tmp25 = arith.cmpi slt, %m, %cst_12 : tensor<1x64xi32> loc(#loc677) + %tmp27 = tt.broadcast %n : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc678) + %tmp27_118 = tt.broadcast %m : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc678) + %tmp27_119 = arith.cmpi sle, %tmp27, %tmp27_118 : tensor<128x64xi32> loc(#loc678) + %tmp28 = tt.broadcast %tmp25 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc679) + %tmp28_120 = arith.andi %tmp28, %tmp27_119 : tensor<128x64xi1> loc(#loc679) + %tmp29 = arith.cmpi sge, %m, %cst_12 : tensor<1x64xi32> loc(#loc680) + %tmp30 = arith.cmpi slt, %n, %cst_11 : tensor<128x1xi32> loc(#loc681) + %tmp31 = tt.broadcast %tmp29 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc682) + %tmp31_121 = tt.broadcast %tmp30 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc682) + %tmp31_122 = arith.andi %tmp31, %tmp31_121 : tensor<128x64xi1> loc(#loc682) + %tmp32 = arith.extui %tmp30 : tensor<128x1xi1> to tensor<128x1xi32> loc(#loc683) + %tmp32_123 = arith.cmpi eq, %tmp32, %cst_11 : tensor<128x1xi32> loc(#loc683) + %tmp33 = tt.broadcast %tmp32_123 : tensor<128x1xi1> -> tensor<128x64xi1> loc(#loc684) + %tmp33_124 = arith.andi %tmp31, %tmp33 : tensor<128x64xi1> loc(#loc684) + %tmp36 = arith.remsi %m, %cst_6 : tensor<1x64xi32> loc(#loc685) + %tmp36_125 = arith.cmpi ne, %tmp36, %cst_12 : tensor<1x64xi32> loc(#loc686) + %tmp36_126 = arith.divsi %m, %cst_6 : tensor<1x64xi32> loc(#loc687) + %tmp36_127 = arith.subi %tmp36_126, %cst_9 : tensor<1x64xi32> loc(#loc688) + %tmp36_128 = arith.select %tmp36_125, %tmp36_127, %tmp36_126 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc689) + %tmp36_129 = arith.select %tmp25, %tmp36_128, %tmp36_126 : tensor<1x64xi1>, tensor<1x64xi32> loc(#loc690) + %tmp38 = arith.remsi %n, %cst_7 : tensor<128x1xi32> loc(#loc691) + %tmp38_130 = arith.cmpi ne, %tmp38, %cst_11 : tensor<128x1xi32> loc(#loc692) + %tmp38_131 = arith.divsi %n, %cst_7 : tensor<128x1xi32> loc(#loc693) + %tmp38_132 = arith.subi %tmp38_131, %cst_10 : tensor<128x1xi32> loc(#loc694) + %tmp38_133 = arith.select %tmp38_130, %tmp38_132, %tmp38_131 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc695) + %tmp38_134 = arith.select %tmp30, %tmp38_133, %tmp38_131 : tensor<128x1xi1>, tensor<128x1xi32> loc(#loc696) + %tmp39 = tt.broadcast %tmp36_129 : tensor<1x64xi32> -> tensor<128x64xi32> loc(#loc697) + %tmp39_135 = tt.broadcast %tmp38_134 : tensor<128x1xi32> -> tensor<128x64xi32> loc(#loc697) + %tmp39_136 = arith.cmpi eq, %tmp39, %tmp39_135 : tensor<128x64xi32> loc(#loc697) + %tmp40 = arith.andi %tmp33_124, %tmp39_136 : tensor<128x64xi1> loc(#loc698) + %tmp41 = arith.ori %tmp31_122, %tmp40 : tensor<128x64xi1> loc(#loc699) + %tmp42 = arith.ori %tmp28_120, %tmp41 : tensor<128x64xi1> loc(#loc700) + %post_mod_scores_137 = arith.select %tmp42, %post_mod_scores, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc701) + %post_mod_scores_138 = arith.mulf %post_mod_scores_137, %cst_8 : tensor<128x64xf32> loc(#loc702) + %pT = tt.expand_dims %lse_116 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc703) + %pT_139 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc704) + %pT_140 = arith.subf %post_mod_scores_138, %pT_139 : tensor<128x64xf32> loc(#loc704) + %pT_141 = math.exp2 %pT_140 : tensor<128x64xf32> loc(#loc705) + %do = tt.expand_dims %offs_m1_104 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc875) + %do_142 = tt.splat %ks0 : i32 -> tensor<64x1xi32> loc(#loc876) + %do_143 = arith.cmpi slt, %do, %do_142 : tensor<64x1xi32> loc(#loc876) + %do_144 = tt.broadcast %do_143 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc877) + %do_145 = tt.load %do_ptrs_106, %do_144, %cst_3 : tensor<64x128x!tt.ptr> loc(#loc877) + %dv_146 = arith.truncf %pT_141 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc707) + %dv_147 = tt.dot %dv_146, %do_145, %dv_103, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc708) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc709) + %Di_148 = tt.addptr %Di, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc709) + %Di_149 = tt.load %Di_148, %lse_111 : tensor<64x!tt.ptr> loc(#loc710) + %dpT = tt.trans %do_145 {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc711) + %dpT_150 = tt.dot %v, %dpT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc712) + %dsT = tt.expand_dims %Di_149 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc713) + %dsT_151 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc714) + %dsT_152 = arith.subf %dpT_150, %dsT_151 : tensor<128x64xf32> loc(#loc714) + %dsT_153 = arith.mulf %pT_141, %dsT_152 : tensor<128x64xf32> loc(#loc715) + %grad_scores = arith.select %qT_109, %dsT_153, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc716) + %dsT_154 = arith.select %tmp42, %grad_scores, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc717) + %dk_155 = arith.truncf %dsT_154 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc718) + %dk_156 = tt.trans %qT_110 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc719) + %dk_157 = tt.dot %dk_155, %dk_156, %dk_102, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc720) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc878) + %cur_block = tt.addptr %q_indices, %cur_block_idx : !tt.ptr, i32 loc(#loc879) + %cur_block_158 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc880) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc881) + %next_block_159 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_61 : i32 loc(#loc882) + %next_block_160 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc883) + %next_block_161 = tt.load %next_block_160, %next_block_159 evictionPolicy = evict_last : !tt.ptr loc(#loc884) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc885) + %needs_jump_162 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc886) + %needs_jump_163 = arith.cmpi eq, %needs_jump_162, %c0_i32 : i32 loc(#loc887) + %jump_to_block = arith.subi %next_block_161, %cur_block_158 : i32 loc(#loc888) + %jump_to_block_164 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc889) + %jump_to_block_165 = arith.subi %jump_to_block_164, %c64_i32 : i32 loc(#loc890) + %offset = arith.extui %needs_jump_163 : i1 to i32 loc(#loc891) + %offset_166 = arith.muli %jump_to_block_165, %offset : i32 loc(#loc891) + %offset_167 = arith.subi %c1_i32, %offset : i32 loc(#loc892) + %offset_168 = arith.muli %offset_167, %c64_i32 : i32 loc(#loc893) + %offset_169 = arith.addi %offset_166, %offset_168 : i32 loc(#loc894) + %qT_ptrs_170 = arith.muli %offset_169, %c4096_i32 : i32 loc(#loc722) + %qT_ptrs_171 = tt.splat %qT_ptrs_170 : i32 -> tensor<128x64xi32> loc(#loc723) + %qT_ptrs_172 = tt.addptr %qT_ptrs_105, %qT_ptrs_171 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc723) + %do_ptrs_173 = arith.muli %offset_169, %c128_i32 : i32 loc(#loc724) + %do_ptrs_174 = tt.splat %do_ptrs_173 : i32 -> tensor<64x128xi32> loc(#loc725) + %do_ptrs_175 = tt.addptr %do_ptrs_106, %do_ptrs_174 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc725) + %offs_m1_176 = tt.splat %offset_169 : i32 -> tensor<64xi32> loc(#loc726) + %offs_m1_177 = arith.addi %offs_m1_104, %offs_m1_176 : tensor<64xi32> loc(#loc726) + scf.yield %dk_157, %dv_147, %offs_m1_177, %qT_ptrs_172, %do_ptrs_175 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc580) + } loc(#loc933) + %q_indices_82 = tt.addptr %arg_FULL_Q_IDX, %pid : !tt.ptr, i32 loc(#loc581) + %q_start_83 = tt.load %q_indices_82 : !tt.ptr loc(#loc582) + %q_start_84 = arith.muli %q_start_83, %c128_i32 : i32 loc(#loc583) + %sparse_q_num_blocks_85 = tt.addptr %arg_FULL_Q_NUM_BLKS, %pid : !tt.ptr, i32 loc(#loc584) + %sparse_q_num_blocks_86 = tt.load %sparse_q_num_blocks_85 : !tt.ptr loc(#loc585) + %offs_m1_87 = tt.splat %q_start_84 : i32 -> tensor<64xi32> loc(#loc586) + %offs_m1_88 = arith.addi %offs_m1_87, %offs_m1 : tensor<64xi32> loc(#loc586) + %qT_ptrs_89 = tt.expand_dims %offs_m1_88 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc727) + %qT_ptrs_90 = arith.muli %qT_ptrs_89, %cst_0 : tensor<1x64xi32> loc(#loc728) + %qT_ptrs_91 = tt.addptr %qT_ptrs_65, %qT_ptrs_90 : tensor<1x64x!tt.ptr>, tensor<1x64xi32> loc(#loc729) + %qT_ptrs_92 = tt.broadcast %qT_ptrs_91 : tensor<1x64x!tt.ptr> -> tensor<128x64x!tt.ptr> loc(#loc730) + %qT_ptrs_93 = tt.addptr %qT_ptrs_92, %qT_ptrs_69 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc730) + %do_ptrs_94 = tt.expand_dims %offs_m1_88 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc731) + %do_ptrs_95 = arith.muli %do_ptrs_94, %cst : tensor<64x1xi32> loc(#loc732) + %do_ptrs_96 = tt.addptr %do_ptrs_72, %do_ptrs_95 : tensor<64x1x!tt.ptr>, tensor<64x1xi32> loc(#loc733) + %do_ptrs_97 = tt.broadcast %do_ptrs_96 : tensor<64x1x!tt.ptr> -> tensor<64x128x!tt.ptr> loc(#loc734) + %do_ptrs_98 = tt.addptr %do_ptrs_97, %do_ptrs_75 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc734) + %hi_99 = arith.muli %sparse_q_num_blocks_86, %c2_i32 : i32 loc(#loc735) + %hi_100 = arith.minsi %hi_99, %hi_79 : i32 loc(#loc736) + %do_ptrs_101:5 = scf.for %start_m = %c0_i32 to %hi_100 step %c1_i32 iter_args(%dk_102 = %do_ptrs_81#0, %dv_103 = %do_ptrs_81#1, %offs_m1_104 = %offs_m1_88, %qT_ptrs_105 = %qT_ptrs_93, %do_ptrs_106 = %do_ptrs_98) -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr>) : i32 { + %qT = tt.expand_dims %offs_m1_104 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> loc(#loc895) + %qT_107 = tt.splat %ks0 : i32 -> tensor<1x64xi32> loc(#loc896) + %qT_108 = arith.cmpi slt, %qT, %qT_107 : tensor<1x64xi32> loc(#loc896) + %qT_109 = tt.broadcast %qT_108 : tensor<1x64xi1> -> tensor<128x64xi1> loc(#loc897) + %qT_110 = tt.load %qT_ptrs_105, %qT_109, %cst_16 : tensor<128x64x!tt.ptr> loc(#loc897) + %lse = tt.splat %ks0 : i32 -> tensor<64xi32> loc(#loc738) + %lse_111 = arith.cmpi slt, %offs_m1_104, %lse : tensor<64xi32> loc(#loc738) + %lse_112 = tt.splat %LSE1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc739) + %lse_113 = tt.addptr %lse_112, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc739) + %lse_114 = tt.load %lse_113, %lse_111 : tensor<64x!tt.ptr> loc(#loc740) + %lse_115 = arith.cmpf oeq, %lse_114, %cst_5 : tensor<64xf32> loc(#loc741) + %lse_116 = arith.select %lse_115, %cst_4, %lse_114 : tensor<64xi1>, tensor<64xf32> loc(#loc742) + %qkT = tt.dot %k_40, %qT_110, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc743) + %qkT_117 = arith.mulf %qkT, %cst_14 : tensor<128x64xf32> loc(#loc744) + %post_mod_scores = arith.select %qT_109, %qkT_117, %cst_13 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc745) + %post_mod_scores_118 = arith.mulf %post_mod_scores, %cst_8 : tensor<128x64xf32> loc(#loc746) + %pT = tt.expand_dims %lse_116 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc747) + %pT_119 = tt.broadcast %pT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc748) + %pT_120 = arith.subf %post_mod_scores_118, %pT_119 : tensor<128x64xf32> loc(#loc748) + %pT_121 = math.exp2 %pT_120 : tensor<128x64xf32> loc(#loc749) + %do = tt.expand_dims %offs_m1_104 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> loc(#loc898) + %do_122 = tt.splat %ks0 : i32 -> tensor<64x1xi32> loc(#loc899) + %do_123 = arith.cmpi slt, %do, %do_122 : tensor<64x1xi32> loc(#loc899) + %do_124 = tt.broadcast %do_123 : tensor<64x1xi1> -> tensor<64x128xi1> loc(#loc900) + %do_125 = tt.load %do_ptrs_106, %do_124, %cst_3 : tensor<64x128x!tt.ptr> loc(#loc900) + %dv_126 = arith.truncf %pT_121 : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc751) + %dv_127 = tt.dot %dv_126, %do_125, %dv_103, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc752) + %Di = tt.splat %DELTA1 : !tt.ptr -> tensor<64x!tt.ptr> loc(#loc753) + %Di_128 = tt.addptr %Di, %offs_m1_104 : tensor<64x!tt.ptr>, tensor<64xi32> loc(#loc753) + %Di_129 = tt.load %Di_128, %lse_111 : tensor<64x!tt.ptr> loc(#loc754) + %dpT = tt.trans %do_125 {order = array} : tensor<64x128xbf16> -> tensor<128x64xbf16> loc(#loc755) + %dpT_130 = tt.dot %v, %dpT, %cst_15, inputPrecision = tf32 : tensor<128x128xbf16> * tensor<128x64xbf16> -> tensor<128x64xf32> loc(#loc756) + %dsT = tt.expand_dims %Di_129 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32> loc(#loc757) + %dsT_131 = tt.broadcast %dsT : tensor<1x64xf32> -> tensor<128x64xf32> loc(#loc758) + %dsT_132 = arith.subf %dpT_130, %dsT_131 : tensor<128x64xf32> loc(#loc758) + %dsT_133 = arith.mulf %pT_121, %dsT_132 : tensor<128x64xf32> loc(#loc759) + %grad_scores = arith.select %qT_109, %dsT_133, %cst_15 : tensor<128x64xi1>, tensor<128x64xf32> loc(#loc760) + %dk_134 = arith.truncf %grad_scores : tensor<128x64xf32> to tensor<128x64xbf16> loc(#loc761) + %dk_135 = tt.trans %qT_110 {order = array} : tensor<128x64xbf16> -> tensor<64x128xbf16> loc(#loc762) + %dk_136 = tt.dot %dk_134, %dk_135, %dk_102, inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x128xbf16> -> tensor<128x128xf32> loc(#loc763) + %cur_block_idx = arith.divsi %start_m, %c2_i32 : i32 loc(#loc901) + %cur_block = tt.addptr %q_indices_82, %cur_block_idx : !tt.ptr, i32 loc(#loc902) + %cur_block_137 = tt.load %cur_block evictionPolicy = evict_last : !tt.ptr loc(#loc903) + %next_block = arith.addi %cur_block_idx, %c1_i32 : i32 loc(#loc904) + %next_block_138 = arith.cmpi slt, %next_block, %sparse_q_num_blocks_86 : i32 loc(#loc905) + %next_block_139 = tt.addptr %cur_block, %c1_i32 : !tt.ptr, i32 loc(#loc906) + %next_block_140 = tt.load %next_block_139, %next_block_138 evictionPolicy = evict_last : !tt.ptr loc(#loc907) + %needs_jump = arith.addi %start_m, %c1_i32 : i32 loc(#loc908) + %needs_jump_141 = arith.remsi %needs_jump, %c2_i32 : i32 loc(#loc909) + %needs_jump_142 = arith.cmpi eq, %needs_jump_141, %c0_i32 : i32 loc(#loc910) + %jump_to_block = arith.subi %next_block_140, %cur_block_137 : i32 loc(#loc911) + %jump_to_block_143 = arith.muli %jump_to_block, %c128_i32 : i32 loc(#loc912) + %jump_to_block_144 = arith.subi %jump_to_block_143, %c64_i32 : i32 loc(#loc913) + %offset = arith.extui %needs_jump_142 : i1 to i32 loc(#loc914) + %offset_145 = arith.muli %jump_to_block_144, %offset : i32 loc(#loc914) + %offset_146 = arith.subi %c1_i32, %offset : i32 loc(#loc915) + %offset_147 = arith.muli %offset_146, %c64_i32 : i32 loc(#loc916) + %offset_148 = arith.addi %offset_145, %offset_147 : i32 loc(#loc917) + %qT_ptrs_149 = arith.muli %offset_148, %c4096_i32 : i32 loc(#loc765) + %qT_ptrs_150 = tt.splat %qT_ptrs_149 : i32 -> tensor<128x64xi32> loc(#loc766) + %qT_ptrs_151 = tt.addptr %qT_ptrs_105, %qT_ptrs_150 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> loc(#loc766) + %do_ptrs_152 = arith.muli %offset_148, %c128_i32 : i32 loc(#loc767) + %do_ptrs_153 = tt.splat %do_ptrs_152 : i32 -> tensor<64x128xi32> loc(#loc768) + %do_ptrs_154 = tt.addptr %do_ptrs_106, %do_ptrs_153 : tensor<64x128x!tt.ptr>, tensor<64x128xi32> loc(#loc768) + %offs_m1_155 = tt.splat %offset_148 : i32 -> tensor<64xi32> loc(#loc769) + %offs_m1_156 = arith.addi %offs_m1_104, %offs_m1_155 : tensor<64xi32> loc(#loc769) + scf.yield %dk_136, %dv_127, %offs_m1_156, %qT_ptrs_151, %do_ptrs_154 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<64xi32>, tensor<128x64x!tt.ptr>, tensor<64x128x!tt.ptr> loc(#loc588) + } loc(#loc934) + scf.yield %do_ptrs_101#1, %do_ptrs_101#0 : tensor<128x128xf32>, tensor<128x128xf32> loc(#loc292) + } loc(#loc651) + %dv_ptrs = tt.splat %DV : !tt.ptr -> tensor<128x1x!tt.ptr> loc(#loc589) + %dv_ptrs_45 = tt.addptr %dv_ptrs, %ptr_31 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> loc(#loc589) + %dv_ptrs_46 = tt.broadcast %dv_ptrs_45 : tensor<128x1x!tt.ptr> -> tensor<128x128x!tt.ptr> loc(#loc590) + %dv_ptrs_47 = tt.addptr %dv_ptrs_46, %ptr_36 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc590) + %11 = arith.cmpi slt, %ptr_34, %cst_20 : tensor<1x128xi32> loc(#loc295) + %12 = tt.broadcast %11 : tensor<1x128xi1> -> tensor<128x128xi1> loc(#loc296) + %13 = arith.andi %k_39, %12 : tensor<128x128xi1> loc(#loc296) + %14 = arith.truncf %dk#0 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc297) + tt.store %dv_ptrs_47, %14, %13 : tensor<128x128x!tt.ptr> loc(#loc297) + %dk_48 = arith.mulf %dk#1, %cst_21 : tensor<128x128xf32> loc(#loc591) + %15 = tt.splat %k_adj : i32 -> tensor<1x128xi32> loc(#loc299) + %16 = arith.addi %ptr_34, %15 : tensor<1x128xi32> loc(#loc299) + %17 = tt.broadcast %16 : tensor<1x128xi32> -> tensor<128x128xi32> loc(#loc300) + %18 = tt.broadcast %ptr_31 : tensor<128x1xi32> -> tensor<128x128xi32> loc(#loc300) + %19 = arith.addi %17, %18 : tensor<128x128xi32> loc(#loc300) + %20 = tt.splat %out_ptr0 : !tt.ptr -> tensor<128x128x!tt.ptr> loc(#loc301) + %21 = tt.addptr %20, %19 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> loc(#loc301) + %22 = arith.truncf %dk_48 : tensor<128x128xf32> to tensor<128x128xbf16> loc(#loc302) + tt.store %21, %22, %k_39 : tensor<128x128x!tt.ptr> loc(#loc302) + } loc(#loc29) + tt.return loc(#loc303) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":103:9) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":94:54) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:74) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:66) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:100) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:91) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:82) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:59) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":97:111) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":100:58) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":111:24) +#loc13 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:22) +#loc14 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":112:36) +#loc15 = loc("/workspace/hanrui/specforge/lib/python3.11/site-packages/triton/language/standard.py":41:28) +#loc16 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":113:34) +#loc17 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":115:27) +#loc18 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":116:28) +#loc19 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:25) +#loc20 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":124:59) +#loc21 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:50) +#loc22 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:37) +#loc23 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":128:61) +#loc24 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":131:9) +#loc25 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":132:9) +#loc26 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":133:10) +#loc27 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":136:26) +#loc28 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:14) +#loc29 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:7) +#loc30 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":140:24) +#loc31 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:29) +#loc32 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:54) +#loc33 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":144:44) +#loc34 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":145:35) +#loc35 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:30) +#loc36 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:52) +#loc37 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:40) +#loc38 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":158:63) +#loc39 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:32) +#loc40 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:55) +#loc41 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:42) +#loc42 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":159:66) +#loc43 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:30) +#loc44 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:35) +#loc45 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:46) +#loc46 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":161:56) +#loc47 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":163:17) +#loc48 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":164:19) +#loc49 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":167:19) +#loc50 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":168:21) +#loc51 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":169:25) +#loc52 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":174:36) +#loc53 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":175:29) +#loc54 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:27) +#loc55 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":178:107) +#loc56 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:38) +#loc57 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:20) +#loc58 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:56) +#loc59 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":789:49) +#loc60 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:52) +#loc61 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:23) +#loc62 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":179:111) +#loc63 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:58) +#loc64 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:34) +#loc65 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":188:25) +#loc66 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:33) +#loc67 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":189:26) +#loc68 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:30) +#loc69 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":190:50) +#loc70 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":191:18) +#loc71 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":195:30) +#loc72 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:27) +#loc73 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":196:41) +#loc74 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:53) +#loc75 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":197:39) +#loc76 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:42) +#loc77 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":199:29) +#loc78 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:26) +#loc79 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":207:12) +#loc80 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:37) +#loc81 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:18) +#loc82 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:56) +#loc83 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":390:49) +#loc84 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:18) +#loc85 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":391:49) +#loc86 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:43) +#loc87 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:90) +#loc88 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:101) +#loc89 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":395:63) +#loc90 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":397:28) +#loc91 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:41) +#loc92 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":458:105) +#loc93 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":405:12) +#loc94 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:52) +#loc95 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":795:23) +#loc96 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":459:19) +#loc97 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":461:14) +#loc98 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":762:21) +#loc99 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":464:46) +#loc100 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":467:46) +#loc101 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":476:79) +#loc102 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":481:22) +#loc103 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":483:23) +#loc104 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":484:22) +#loc105 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":485:23) +#loc106 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":486:22) +#loc107 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":487:22) +#loc108 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":488:24) +#loc109 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":489:23) +#loc110 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:70) +#loc111 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:79) +#loc112 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:91) +#loc113 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:99) +#loc114 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:102) +#loc115 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":492:119) +#loc116 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:70) +#loc117 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:79) +#loc118 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:91) +#loc119 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:99) +#loc120 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:102) +#loc121 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":494:119) +#loc122 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":495:25) +#loc123 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":496:24) +#loc124 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":497:23) +#loc125 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":498:23) +#loc126 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":503:69) +#loc127 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":506:27) +#loc128 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:39) +#loc129 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":507:21) +#loc130 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":510:104) +#loc131 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":512:20) +#loc132 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:22) +#loc133 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:19) +#loc134 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":513:14) +#loc135 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":520:71) +#loc136 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":531:43) +#loc137 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":533:15) +#loc138 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:30) +#loc139 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":535:21) +#loc140 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":752:33) +#loc141 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":411:64) +#loc142 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:38) +#loc143 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":753:24) +#loc144 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:109) +#loc145 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:113) +#loc146 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:55) +#loc147 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":754:25) +#loc148 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:30) +#loc149 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:35) +#loc150 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":755:60) +#loc151 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:34) +#loc152 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:48) +#loc153 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":756:63) +#loc154 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:29) +#loc155 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:47) +#loc156 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:61) +#loc157 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":757:42) +#loc158 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:28) +#loc159 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":414:19) +#loc160 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":415:19) +#loc161 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":417:19) +#loc162 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":417:8) +#loc163 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":214:39) +#loc164 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:31) +#loc165 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":215:45) +#loc166 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:62) +#loc167 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":216:43) +#loc168 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":218:33) +#loc169 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":226:16) +#loc170 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:24) +#loc171 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":231:56) +#loc172 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":232:14) +#loc173 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:87) +#loc174 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:69) +#loc175 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":236:30) +#loc176 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":252:25) +#loc177 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":253:29) +#loc178 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":256:107) +#loc179 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":257:107) +#loc180 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":262:30) +#loc181 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:32) +#loc182 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":263:51) +#loc183 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:34) +#loc184 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:56) +#loc185 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:44) +#loc186 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":266:67) +#loc187 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:36) +#loc188 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:59) +#loc189 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:46) +#loc190 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":267:70) +#loc191 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:34) +#loc192 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:39) +#loc193 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:50) +#loc194 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":269:60) +#loc195 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":271:21) +#loc196 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":272:23) +#loc197 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":275:25) +#loc198 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":276:29) +#loc199 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":286:32) +#loc200 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:30) +#loc201 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":287:43) +#loc202 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:55) +#loc203 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":288:42) +#loc204 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:45) +#loc205 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":290:32) +#loc206 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:26) +#loc207 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":298:16) +#loc208 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:37) +#loc209 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:18) +#loc210 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:56) +#loc211 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":583:49) +#loc212 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:27) +#loc213 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:38) +#loc214 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:19) +#loc215 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":584:51) +#loc216 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:42) +#loc217 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:87) +#loc218 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:98) +#loc219 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":590:61) +#loc220 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":592:28) +#loc221 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":651:105) +#loc222 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":600:12) +#loc223 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:52) +#loc224 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:28) +#loc225 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":656:22) +#loc226 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:26) +#loc227 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":657:46) +#loc228 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":658:20) +#loc229 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":660:15) +#loc230 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":662:46) +#loc231 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":665:46) +#loc232 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":674:78) +#loc233 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":679:24) +#loc234 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":681:25) +#loc235 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":682:24) +#loc236 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":683:25) +#loc237 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":684:24) +#loc238 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":685:24) +#loc239 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":686:25) +#loc240 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":687:24) +#loc241 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:70) +#loc242 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:79) +#loc243 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:91) +#loc244 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:99) +#loc245 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:102) +#loc246 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":690:119) +#loc247 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:70) +#loc248 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:79) +#loc249 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:91) +#loc250 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:99) +#loc251 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:102) +#loc252 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":692:119) +#loc253 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":693:25) +#loc254 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":694:24) +#loc255 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":695:24) +#loc256 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":696:24) +#loc257 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":700:69) +#loc258 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":703:27) +#loc259 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:44) +#loc260 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:40) +#loc261 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":704:22) +#loc262 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":797:41) +#loc263 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":705:99) +#loc264 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:24) +#loc265 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":708:43) +#loc266 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:29) +#loc267 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":712:21) +#loc268 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:29) +#loc269 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":714:20) +#loc270 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:25) +#loc271 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:22) +#loc272 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":715:16) +#loc273 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":723:70) +#loc274 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":737:45) +#loc275 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:24) +#loc276 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:52) +#loc277 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":739:43) +#loc278 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":605:62) +#loc279 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:28) +#loc280 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":608:19) +#loc281 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:28) +#loc282 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":609:19) +#loc283 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":610:19) +#loc284 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":610:8) +#loc285 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":306:41) +#loc286 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:34) +#loc287 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":307:47) +#loc288 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:64) +#loc289 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":308:46) +#loc290 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":310:36) +#loc291 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":318:20) +#loc292 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":303:12) +#loc293 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:23) +#loc294 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":323:55) +#loc295 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:71) +#loc296 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:61) +#loc297 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":332:30) +#loc298 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":334:14) +#loc299 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:55) +#loc300 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:69) +#loc301 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:29) +#loc302 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":345:99) +#loc303 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/l7/cl72q7ka2ycg3khwhfq7vugjhzdkalvz6whnqhwzpmvvuxsncdc7.py":139:4) +#loc323 = loc("HQ"(#loc2)) +#loc324 = loc("pid"(#loc12)) +#loc325 = loc("NUM_KV_BLOCKS"(#loc14)) +#loc326 = loc("NUM_Q_BLOCKS"(#loc16)) +#loc327 = loc("off_zq"(#loc17)) +#loc328 = loc("off_hkv"(#loc18)) +#loc329 = loc("k_adj"(#loc19)) +#loc330 = loc("k_adj"(#loc20)) +#loc331 = loc("dv_adj"(#loc21)) +#loc332 = loc("dv_adj"(#loc22)) +#loc333 = loc("dv_adj"(#loc23)) +#loc334 = loc("K"(#loc24)) +#loc335 = loc("V"(#loc25)) +#loc336 = loc("DV"(#loc26)) +#loc337 = loc("offs_k"(#loc27)) +#loc338 = loc("off_pid"(#loc30)) +#loc339 = loc("off_hq2"(#loc31)) +#loc340 = loc("off_hq2"(#loc32)) +#loc341 = loc("off_hq2"(#loc33)) +#loc342 = loc("start_m2_block"(#loc34)) +#loc343 = loc("q_adj2"(#loc35)) +#loc344 = loc("q_adj2"(#loc36)) +#loc345 = loc("q_adj2"(#loc37)) +#loc346 = loc("q_adj2"(#loc38)) +#loc347 = loc("do_adj2"(#loc39)) +#loc348 = loc("do_adj2"(#loc40)) +#loc349 = loc("do_adj2"(#loc41)) +#loc350 = loc("do_adj2"(#loc42)) +#loc351 = loc("off_chz2"(#loc43)) +#loc352 = loc("off_chz2"(#loc44)) +#loc353 = loc("off_chz2"(#loc45)) +#loc354 = loc("off_chz2"(#loc46)) +#loc355 = loc("Q2"(#loc47)) +#loc356 = loc("DO2"(#loc48)) +#loc357 = loc("DQ2"(#loc49)) +#loc358 = loc("LSE2"(#loc50)) +#loc359 = loc("DELTA2"(#loc51)) +#loc360 = loc("start_m2"(#loc52)) +#loc361 = loc("offs_m2"(#loc53)) +#loc362 = loc("ptr"(#loc54)) +#loc363 = loc("q"(#loc55)) +#loc364 = loc("ptr"(#loc56)) +#loc365 = loc("ptr"(#loc57)) +#loc366 = loc("ptr"(#loc58)) +#loc367 = loc("ptr"(#loc59)) +#loc368 = loc("do"(#loc62)) +#loc369 = loc("Di"(#loc63)) +#loc370 = loc("Di"(#loc64)) +#loc371 = loc("Di"(#loc65)) +#loc372 = loc("lse"(#loc66)) +#loc373 = loc("lse"(#loc67)) +#loc374 = loc("lse"(#loc68)) +#loc375 = loc("lse"(#loc69)) +#loc376 = loc("lse"(#loc70)) +#loc377 = loc("kv_indices"(#loc71)) +#loc378 = loc("kv_start"(#loc72)) +#loc379 = loc("kv_start"(#loc73)) +#loc380 = loc("sparse_kv_num_blocks"(#loc74)) +#loc381 = loc("sparse_kv_num_blocks"(#loc75)) +#loc382 = loc("offs_n2"(#loc76)) +#loc383 = loc("offs_n2"(#loc77)) +#loc384 = loc("kT_ptrs"(#loc78)) +#loc385 = loc("dq"(#loc79)) +#loc386 = loc("kT_ptrs"(#loc80)) +#loc387 = loc("kT_ptrs"(#loc81)) +#loc388 = loc("kT_ptrs"(#loc82)) +#loc389 = loc("kT_ptrs"(#loc83)) +#loc390 = loc("vT_ptrs"(#loc84)) +#loc391 = loc("vT_ptrs"(#loc85)) +#loc392 = loc("hi"(#loc86)) +#loc393 = loc("hi"(#loc87)) +#loc394 = loc("hi"(#loc88)) +#loc395 = loc("hi"(#loc89)) +#loc396 = loc("dq"(#loc90)) +#loc397 = loc("kT"(#loc92)) +#loc398 = loc("dq"(#loc93)) +#loc399 = loc("qk"(#loc96)) +#loc400 = loc("qk"(#loc97)) +#loc401 = loc("n"(#loc99)) +#loc402 = loc("m"(#loc100)) +#loc403 = loc("post_mod_scores"(#loc101)) +#loc404 = loc("tmp3"(#loc102)) +#loc405 = loc("tmp5"(#loc103)) +#loc406 = loc("tmp6"(#loc104)) +#loc407 = loc("tmp7"(#loc105)) +#loc408 = loc("tmp8"(#loc106)) +#loc409 = loc("tmp9"(#loc107)) +#loc410 = loc("tmp10"(#loc108)) +#loc411 = loc("tmp11"(#loc109)) +#loc412 = loc("tmp14"(#loc110)) +#loc413 = loc("tmp14"(#loc111)) +#loc414 = loc("tmp14"(#loc112)) +#loc415 = loc("tmp14"(#loc113)) +#loc416 = loc("tmp14"(#loc114)) +#loc417 = loc("tmp14"(#loc115)) +#loc418 = loc("tmp16"(#loc116)) +#loc419 = loc("tmp16"(#loc117)) +#loc420 = loc("tmp16"(#loc118)) +#loc421 = loc("tmp16"(#loc119)) +#loc422 = loc("tmp16"(#loc120)) +#loc423 = loc("tmp16"(#loc121)) +#loc424 = loc("tmp17"(#loc122)) +#loc425 = loc("tmp18"(#loc123)) +#loc426 = loc("tmp19"(#loc124)) +#loc427 = loc("tmp20"(#loc125)) +#loc428 = loc("post_mod_scores"(#loc126)) +#loc429 = loc("post_mod_scores"(#loc127)) +#loc430 = loc("p"(#loc128)) +#loc431 = loc("p"(#loc129)) +#loc432 = loc("vT"(#loc130)) +#loc433 = loc("dp"(#loc131)) +#loc434 = loc("ds"(#loc132)) +#loc435 = loc("ds"(#loc133)) +#loc436 = loc("ds"(#loc134)) +#loc437 = loc("grad_scores"(#loc135)) +#loc438 = loc("ds"(#loc136)) +#loc439 = loc("ds"(#loc137)) +#loc440 = loc("dq"(#loc138)) +#loc441 = loc("dq"(#loc139)) +#loc442 = loc("cur_block_idx"(#loc140)) +#loc443 = loc("offset"(#loc141)) +#loc444 = loc("cur_block"(#loc142)) +#loc445 = loc("cur_block"(#loc143)) +#loc446 = loc("next_block"(#loc144)) +#loc447 = loc("next_block"(#loc145)) +#loc448 = loc("next_block"(#loc146)) +#loc449 = loc("next_block"(#loc147)) +#loc450 = loc("needs_jump"(#loc148)) +#loc451 = loc("needs_jump"(#loc149)) +#loc452 = loc("needs_jump"(#loc150)) +#loc453 = loc("jump_to_block"(#loc151)) +#loc454 = loc("jump_to_block"(#loc152)) +#loc455 = loc("jump_to_block"(#loc153)) +#loc456 = loc("offset"(#loc154)) +#loc457 = loc("offset"(#loc155)) +#loc458 = loc("offset"(#loc156)) +#loc459 = loc("offset"(#loc157)) +#loc460 = loc("kT_ptrs"(#loc158)) +#loc461 = loc("kT_ptrs"(#loc159)) +#loc462 = loc("vT_ptrs"(#loc160)) +#loc463 = loc("offs_n2"(#loc161)) +#loc464 = loc("kv_indices"(#loc163)) +#loc465 = loc("kv_start"(#loc164)) +#loc466 = loc("kv_start"(#loc165)) +#loc467 = loc("sparse_kv_num_blocks"(#loc166)) +#loc468 = loc("sparse_kv_num_blocks"(#loc167)) +#loc469 = loc("offs_n2"(#loc168)) +#loc470 = loc("dq"(#loc169)) +#loc471 = loc("dq_ptrs"(#loc170)) +#loc472 = loc("dq_ptrs"(#loc171)) +#loc473 = loc("dq"(#loc172)) +#loc474 = loc("start_n1"(#loc176)) +#loc475 = loc("offs_n1"(#loc177)) +#loc476 = loc("k"(#loc178)) +#loc477 = loc("v"(#loc179)) +#loc478 = loc("dv"(#loc180)) +#loc479 = loc("off_hq1"(#loc181)) +#loc480 = loc("off_hq1"(#loc182)) +#loc481 = loc("q_adj1"(#loc183)) +#loc482 = loc("q_adj1"(#loc184)) +#loc483 = loc("q_adj1"(#loc185)) +#loc484 = loc("q_adj1"(#loc186)) +#loc485 = loc("do_adj1"(#loc187)) +#loc486 = loc("do_adj1"(#loc188)) +#loc487 = loc("do_adj1"(#loc189)) +#loc488 = loc("do_adj1"(#loc190)) +#loc489 = loc("off_chz1"(#loc191)) +#loc490 = loc("off_chz1"(#loc192)) +#loc491 = loc("off_chz1"(#loc193)) +#loc492 = loc("off_chz1"(#loc194)) +#loc493 = loc("Q1"(#loc195)) +#loc494 = loc("DO1"(#loc196)) +#loc495 = loc("LSE1"(#loc197)) +#loc496 = loc("DELTA1"(#loc198)) +#loc497 = loc("q_indices"(#loc199)) +#loc498 = loc("q_start"(#loc200)) +#loc499 = loc("q_start"(#loc201)) +#loc500 = loc("sparse_q_num_blocks"(#loc202)) +#loc501 = loc("sparse_q_num_blocks"(#loc203)) +#loc502 = loc("offs_m1"(#loc204)) +#loc503 = loc("offs_m1"(#loc205)) +#loc504 = loc("qT_ptrs"(#loc206)) +#loc505 = loc("qT_ptrs"(#loc208)) +#loc506 = loc("qT_ptrs"(#loc209)) +#loc507 = loc("qT_ptrs"(#loc210)) +#loc508 = loc("qT_ptrs"(#loc211)) +#loc509 = loc("do_ptrs"(#loc212)) +#loc510 = loc("do_ptrs"(#loc213)) +#loc511 = loc("do_ptrs"(#loc214)) +#loc512 = loc("do_ptrs"(#loc215)) +#loc513 = loc("hi"(#loc216)) +#loc514 = loc("hi"(#loc217)) +#loc515 = loc("hi"(#loc218)) +#loc516 = loc("hi"(#loc219)) +#loc517 = loc("dk"(#loc220)) +#loc518 = loc("qT"(#loc221)) +#loc519 = loc(callsite(#loc222 at #loc207)) +#loc520 = loc("lse"(#loc223)) +#loc521 = loc("lse"(#loc224)) +#loc522 = loc("lse"(#loc225)) +#loc523 = loc("lse"(#loc226)) +#loc524 = loc("lse"(#loc227)) +#loc525 = loc("qkT"(#loc228)) +#loc526 = loc("qkT"(#loc229)) +#loc527 = loc("m"(#loc230)) +#loc528 = loc("n"(#loc231)) +#loc529 = loc("post_mod_scores"(#loc232)) +#loc530 = loc("tmp25"(#loc233)) +#loc531 = loc("tmp27"(#loc234)) +#loc532 = loc("tmp28"(#loc235)) +#loc533 = loc("tmp29"(#loc236)) +#loc534 = loc("tmp30"(#loc237)) +#loc535 = loc("tmp31"(#loc238)) +#loc536 = loc("tmp32"(#loc239)) +#loc537 = loc("tmp33"(#loc240)) +#loc538 = loc("tmp36"(#loc241)) +#loc539 = loc("tmp36"(#loc242)) +#loc540 = loc("tmp36"(#loc243)) +#loc541 = loc("tmp36"(#loc244)) +#loc542 = loc("tmp36"(#loc245)) +#loc543 = loc("tmp36"(#loc246)) +#loc544 = loc("tmp38"(#loc247)) +#loc545 = loc("tmp38"(#loc248)) +#loc546 = loc("tmp38"(#loc249)) +#loc547 = loc("tmp38"(#loc250)) +#loc548 = loc("tmp38"(#loc251)) +#loc549 = loc("tmp38"(#loc252)) +#loc550 = loc("tmp39"(#loc253)) +#loc551 = loc("tmp40"(#loc254)) +#loc552 = loc("tmp41"(#loc255)) +#loc553 = loc("tmp42"(#loc256)) +#loc554 = loc("post_mod_scores"(#loc257)) +#loc555 = loc("post_mod_scores"(#loc258)) +#loc556 = loc("pT"(#loc259)) +#loc557 = loc("pT"(#loc260)) +#loc558 = loc("pT"(#loc261)) +#loc559 = loc("do"(#loc263)) +#loc560 = loc("dv"(#loc264)) +#loc561 = loc("dv"(#loc265)) +#loc562 = loc("Di"(#loc266)) +#loc563 = loc("Di"(#loc267)) +#loc564 = loc("dpT"(#loc268)) +#loc565 = loc("dpT"(#loc269)) +#loc566 = loc("dsT"(#loc270)) +#loc567 = loc("dsT"(#loc271)) +#loc568 = loc("dsT"(#loc272)) +#loc569 = loc("grad_scores"(#loc273)) +#loc570 = loc("dsT"(#loc274)) +#loc571 = loc("dk"(#loc275)) +#loc572 = loc("dk"(#loc276)) +#loc573 = loc("dk"(#loc277)) +#loc574 = loc("offset"(#loc278)) +#loc575 = loc("qT_ptrs"(#loc279)) +#loc576 = loc("qT_ptrs"(#loc280)) +#loc577 = loc("do_ptrs"(#loc281)) +#loc578 = loc("do_ptrs"(#loc282)) +#loc579 = loc("offs_m1"(#loc283)) +#loc580 = loc(callsite(#loc284 at #loc207)) +#loc581 = loc("q_indices"(#loc285)) +#loc582 = loc("q_start"(#loc286)) +#loc583 = loc("q_start"(#loc287)) +#loc584 = loc("sparse_q_num_blocks"(#loc288)) +#loc585 = loc("sparse_q_num_blocks"(#loc289)) +#loc586 = loc("offs_m1"(#loc290)) +#loc587 = loc(callsite(#loc222 at #loc291)) +#loc588 = loc(callsite(#loc284 at #loc291)) +#loc589 = loc("dv_ptrs"(#loc293)) +#loc590 = loc("dv_ptrs"(#loc294)) +#loc591 = loc("dk"(#loc298)) +#loc592 = loc(callsite(#loc13 at #loc325)) +#loc593 = loc(callsite(#loc15 at #loc325)) +#loc594 = loc(callsite(#loc13 at #loc326)) +#loc595 = loc(callsite(#loc15 at #loc326)) +#loc596 = loc(callsite(#loc362 at #loc363)) +#loc597 = loc(callsite(#loc364 at #loc363)) +#loc598 = loc(callsite(#loc365 at #loc363)) +#loc599 = loc(callsite(#loc366 at #loc363)) +#loc600 = loc(callsite(#loc367 at #loc363)) +#loc601 = loc(callsite(#loc60 at #loc363)) +#loc602 = loc(callsite(#loc61 at #loc363)) +#loc603 = loc(callsite(#loc364 at #loc368)) +#loc604 = loc(callsite(#loc365 at #loc368)) +#loc605 = loc(callsite(#loc367 at #loc368)) +#loc606 = loc(callsite(#loc61 at #loc368)) +#loc607 = loc(callsite(#loc384 at #loc385)) +#loc608 = loc(callsite(#loc386 at #loc385)) +#loc609 = loc(callsite(#loc387 at #loc385)) +#loc610 = loc(callsite(#loc388 at #loc385)) +#loc611 = loc(callsite(#loc389 at #loc385)) +#loc612 = loc(callsite(#loc390 at #loc385)) +#loc613 = loc(callsite(#loc391 at #loc385)) +#loc614 = loc(callsite(#loc392 at #loc385)) +#loc615 = loc(callsite(#loc393 at #loc385)) +#loc616 = loc(callsite(#loc394 at #loc385)) +#loc617 = loc(callsite(#loc395 at #loc385)) +#loc618 = loc("offs_n2"(#loc396)) +#loc619 = loc(callsite(#loc398 at #loc385)) +#loc620 = loc(callsite(#loc443 at #loc385)) +#loc621 = loc(callsite(#loc460 at #loc385)) +#loc622 = loc(callsite(#loc461 at #loc385)) +#loc623 = loc(callsite(#loc462 at #loc385)) +#loc624 = loc(callsite(#loc463 at #loc385)) +#loc625 = loc(callsite(#loc162 at #loc385)) +#loc626 = loc(callsite(#loc384 at #loc470)) +#loc627 = loc(callsite(#loc386 at #loc470)) +#loc628 = loc(callsite(#loc387 at #loc470)) +#loc629 = loc(callsite(#loc389 at #loc470)) +#loc630 = loc(callsite(#loc390 at #loc470)) +#loc631 = loc(callsite(#loc391 at #loc470)) +#loc632 = loc(callsite(#loc392 at #loc470)) +#loc633 = loc(callsite(#loc395 at #loc470)) +#loc634 = loc(callsite(#loc398 at #loc470)) +#loc635 = loc(callsite(#loc443 at #loc470)) +#loc636 = loc(callsite(#loc460 at #loc470)) +#loc637 = loc(callsite(#loc461 at #loc470)) +#loc638 = loc(callsite(#loc462 at #loc470)) +#loc639 = loc(callsite(#loc463 at #loc470)) +#loc640 = loc(callsite(#loc162 at #loc470)) +#loc641 = loc(callsite(#loc362 at #loc476)) +#loc642 = loc(callsite(#loc364 at #loc476)) +#loc643 = loc(callsite(#loc365 at #loc476)) +#loc644 = loc(callsite(#loc366 at #loc476)) +#loc645 = loc(callsite(#loc367 at #loc476)) +#loc646 = loc(callsite(#loc60 at #loc476)) +#loc647 = loc(callsite(#loc61 at #loc476)) +#loc648 = loc(callsite(#loc365 at #loc477)) +#loc649 = loc(callsite(#loc367 at #loc477)) +#loc650 = loc(callsite(#loc61 at #loc477)) +#loc651 = loc("dk"(#loc478)) +#loc652 = loc(callsite(#loc504 at #loc207)) +#loc653 = loc(callsite(#loc505 at #loc207)) +#loc654 = loc(callsite(#loc506 at #loc207)) +#loc655 = loc(callsite(#loc507 at #loc207)) +#loc656 = loc(callsite(#loc508 at #loc207)) +#loc657 = loc(callsite(#loc509 at #loc207)) +#loc658 = loc(callsite(#loc510 at #loc207)) +#loc659 = loc(callsite(#loc511 at #loc207)) +#loc660 = loc(callsite(#loc512 at #loc207)) +#loc661 = loc(callsite(#loc513 at #loc207)) +#loc662 = loc(callsite(#loc514 at #loc207)) +#loc663 = loc(callsite(#loc515 at #loc207)) +#loc664 = loc(callsite(#loc516 at #loc207)) +#loc665 = loc("dv"(#loc517)) +#loc666 = loc(callsite(#loc518 at #loc519)) +#loc667 = loc(callsite(#loc520 at #loc519)) +#loc668 = loc(callsite(#loc521 at #loc519)) +#loc669 = loc(callsite(#loc522 at #loc519)) +#loc670 = loc(callsite(#loc523 at #loc519)) +#loc671 = loc(callsite(#loc524 at #loc519)) +#loc672 = loc(callsite(#loc525 at #loc519)) +#loc673 = loc(callsite(#loc526 at #loc519)) +#loc674 = loc(callsite(#loc527 at #loc519)) +#loc675 = loc(callsite(#loc528 at #loc519)) +#loc676 = loc(callsite(#loc529 at #loc519)) +#loc677 = loc(callsite(#loc530 at #loc519)) +#loc678 = loc(callsite(#loc531 at #loc519)) +#loc679 = loc(callsite(#loc532 at #loc519)) +#loc680 = loc(callsite(#loc533 at #loc519)) +#loc681 = loc(callsite(#loc534 at #loc519)) +#loc682 = loc(callsite(#loc535 at #loc519)) +#loc683 = loc(callsite(#loc536 at #loc519)) +#loc684 = loc(callsite(#loc537 at #loc519)) +#loc685 = loc(callsite(#loc538 at #loc519)) +#loc686 = loc(callsite(#loc539 at #loc519)) +#loc687 = loc(callsite(#loc540 at #loc519)) +#loc688 = loc(callsite(#loc541 at #loc519)) +#loc689 = loc(callsite(#loc542 at #loc519)) +#loc690 = loc(callsite(#loc543 at #loc519)) +#loc691 = loc(callsite(#loc544 at #loc519)) +#loc692 = loc(callsite(#loc545 at #loc519)) +#loc693 = loc(callsite(#loc546 at #loc519)) +#loc694 = loc(callsite(#loc547 at #loc519)) +#loc695 = loc(callsite(#loc548 at #loc519)) +#loc696 = loc(callsite(#loc549 at #loc519)) +#loc697 = loc(callsite(#loc550 at #loc519)) +#loc698 = loc(callsite(#loc551 at #loc519)) +#loc699 = loc(callsite(#loc552 at #loc519)) +#loc700 = loc(callsite(#loc553 at #loc519)) +#loc701 = loc(callsite(#loc554 at #loc519)) +#loc702 = loc(callsite(#loc555 at #loc519)) +#loc703 = loc(callsite(#loc556 at #loc519)) +#loc704 = loc(callsite(#loc557 at #loc519)) +#loc705 = loc(callsite(#loc558 at #loc519)) +#loc706 = loc(callsite(#loc559 at #loc519)) +#loc707 = loc(callsite(#loc560 at #loc519)) +#loc708 = loc(callsite(#loc561 at #loc519)) +#loc709 = loc(callsite(#loc562 at #loc519)) +#loc710 = loc(callsite(#loc563 at #loc519)) +#loc711 = loc(callsite(#loc564 at #loc519)) +#loc712 = loc(callsite(#loc565 at #loc519)) +#loc713 = loc(callsite(#loc566 at #loc519)) +#loc714 = loc(callsite(#loc567 at #loc519)) +#loc715 = loc(callsite(#loc568 at #loc519)) +#loc716 = loc(callsite(#loc569 at #loc519)) +#loc717 = loc(callsite(#loc570 at #loc519)) +#loc718 = loc(callsite(#loc571 at #loc519)) +#loc719 = loc(callsite(#loc572 at #loc519)) +#loc720 = loc(callsite(#loc573 at #loc519)) +#loc721 = loc(callsite(#loc574 at #loc207)) +#loc722 = loc(callsite(#loc575 at #loc207)) +#loc723 = loc(callsite(#loc576 at #loc207)) +#loc724 = loc(callsite(#loc577 at #loc207)) +#loc725 = loc(callsite(#loc578 at #loc207)) +#loc726 = loc(callsite(#loc579 at #loc207)) +#loc727 = loc(callsite(#loc504 at #loc291)) +#loc728 = loc(callsite(#loc505 at #loc291)) +#loc729 = loc(callsite(#loc506 at #loc291)) +#loc730 = loc(callsite(#loc508 at #loc291)) +#loc731 = loc(callsite(#loc509 at #loc291)) +#loc732 = loc(callsite(#loc510 at #loc291)) +#loc733 = loc(callsite(#loc511 at #loc291)) +#loc734 = loc(callsite(#loc512 at #loc291)) +#loc735 = loc(callsite(#loc513 at #loc291)) +#loc736 = loc(callsite(#loc516 at #loc291)) +#loc737 = loc(callsite(#loc518 at #loc587)) +#loc738 = loc(callsite(#loc520 at #loc587)) +#loc739 = loc(callsite(#loc521 at #loc587)) +#loc740 = loc(callsite(#loc522 at #loc587)) +#loc741 = loc(callsite(#loc523 at #loc587)) +#loc742 = loc(callsite(#loc524 at #loc587)) +#loc743 = loc(callsite(#loc525 at #loc587)) +#loc744 = loc(callsite(#loc526 at #loc587)) +#loc745 = loc(callsite(#loc529 at #loc587)) +#loc746 = loc(callsite(#loc555 at #loc587)) +#loc747 = loc(callsite(#loc556 at #loc587)) +#loc748 = loc(callsite(#loc557 at #loc587)) +#loc749 = loc(callsite(#loc558 at #loc587)) +#loc750 = loc(callsite(#loc559 at #loc587)) +#loc751 = loc(callsite(#loc560 at #loc587)) +#loc752 = loc(callsite(#loc561 at #loc587)) +#loc753 = loc(callsite(#loc562 at #loc587)) +#loc754 = loc(callsite(#loc563 at #loc587)) +#loc755 = loc(callsite(#loc564 at #loc587)) +#loc756 = loc(callsite(#loc565 at #loc587)) +#loc757 = loc(callsite(#loc566 at #loc587)) +#loc758 = loc(callsite(#loc567 at #loc587)) +#loc759 = loc(callsite(#loc568 at #loc587)) +#loc760 = loc(callsite(#loc569 at #loc587)) +#loc761 = loc(callsite(#loc571 at #loc587)) +#loc762 = loc(callsite(#loc572 at #loc587)) +#loc763 = loc(callsite(#loc573 at #loc587)) +#loc764 = loc(callsite(#loc574 at #loc291)) +#loc765 = loc(callsite(#loc575 at #loc291)) +#loc766 = loc(callsite(#loc576 at #loc291)) +#loc767 = loc(callsite(#loc577 at #loc291)) +#loc768 = loc(callsite(#loc578 at #loc291)) +#loc769 = loc(callsite(#loc579 at #loc291)) +#loc770 = loc(callsite(#loc13 at #loc615)) +#loc771 = loc(callsite(#loc15 at #loc615)) +#loc772 = loc("kT_ptrs"(#loc618)) +#loc773 = loc(callsite(#loc397 at #loc619)) +#loc774 = loc(callsite(#loc399 at #loc619)) +#loc775 = loc(callsite(#loc400 at #loc619)) +#loc776 = loc(callsite(#loc401 at #loc619)) +#loc777 = loc(callsite(#loc402 at #loc619)) +#loc778 = loc(callsite(#loc403 at #loc619)) +#loc779 = loc(callsite(#loc404 at #loc619)) +#loc780 = loc(callsite(#loc405 at #loc619)) +#loc781 = loc(callsite(#loc406 at #loc619)) +#loc782 = loc(callsite(#loc407 at #loc619)) +#loc783 = loc(callsite(#loc408 at #loc619)) +#loc784 = loc(callsite(#loc409 at #loc619)) +#loc785 = loc(callsite(#loc410 at #loc619)) +#loc786 = loc(callsite(#loc411 at #loc619)) +#loc787 = loc(callsite(#loc412 at #loc619)) +#loc788 = loc(callsite(#loc413 at #loc619)) +#loc789 = loc(callsite(#loc414 at #loc619)) +#loc790 = loc(callsite(#loc415 at #loc619)) +#loc791 = loc(callsite(#loc416 at #loc619)) +#loc792 = loc(callsite(#loc417 at #loc619)) +#loc793 = loc(callsite(#loc418 at #loc619)) +#loc794 = loc(callsite(#loc419 at #loc619)) +#loc795 = loc(callsite(#loc420 at #loc619)) +#loc796 = loc(callsite(#loc421 at #loc619)) +#loc797 = loc(callsite(#loc422 at #loc619)) +#loc798 = loc(callsite(#loc423 at #loc619)) +#loc799 = loc(callsite(#loc424 at #loc619)) +#loc800 = loc(callsite(#loc425 at #loc619)) +#loc801 = loc(callsite(#loc426 at #loc619)) +#loc802 = loc(callsite(#loc427 at #loc619)) +#loc803 = loc(callsite(#loc428 at #loc619)) +#loc804 = loc(callsite(#loc429 at #loc619)) +#loc805 = loc(callsite(#loc430 at #loc619)) +#loc806 = loc(callsite(#loc431 at #loc619)) +#loc807 = loc(callsite(#loc432 at #loc619)) +#loc808 = loc(callsite(#loc433 at #loc619)) +#loc809 = loc(callsite(#loc434 at #loc619)) +#loc810 = loc(callsite(#loc435 at #loc619)) +#loc811 = loc(callsite(#loc436 at #loc619)) +#loc812 = loc(callsite(#loc437 at #loc619)) +#loc813 = loc(callsite(#loc438 at #loc619)) +#loc814 = loc(callsite(#loc439 at #loc619)) +#loc815 = loc(callsite(#loc440 at #loc619)) +#loc816 = loc(callsite(#loc441 at #loc619)) +#loc817 = loc(callsite(#loc442 at #loc620)) +#loc818 = loc(callsite(#loc444 at #loc620)) +#loc819 = loc(callsite(#loc445 at #loc620)) +#loc820 = loc(callsite(#loc446 at #loc620)) +#loc821 = loc(callsite(#loc447 at #loc620)) +#loc822 = loc(callsite(#loc448 at #loc620)) +#loc823 = loc(callsite(#loc449 at #loc620)) +#loc824 = loc(callsite(#loc450 at #loc620)) +#loc825 = loc(callsite(#loc451 at #loc620)) +#loc826 = loc(callsite(#loc452 at #loc620)) +#loc827 = loc(callsite(#loc453 at #loc620)) +#loc828 = loc(callsite(#loc454 at #loc620)) +#loc829 = loc(callsite(#loc455 at #loc620)) +#loc830 = loc(callsite(#loc456 at #loc620)) +#loc831 = loc(callsite(#loc457 at #loc620)) +#loc832 = loc(callsite(#loc458 at #loc620)) +#loc833 = loc(callsite(#loc459 at #loc620)) +#loc834 = loc(callsite(#loc397 at #loc634)) +#loc835 = loc(callsite(#loc399 at #loc634)) +#loc836 = loc(callsite(#loc400 at #loc634)) +#loc837 = loc(callsite(#loc403 at #loc634)) +#loc838 = loc(callsite(#loc429 at #loc634)) +#loc839 = loc(callsite(#loc430 at #loc634)) +#loc840 = loc(callsite(#loc431 at #loc634)) +#loc841 = loc(callsite(#loc432 at #loc634)) +#loc842 = loc(callsite(#loc433 at #loc634)) +#loc843 = loc(callsite(#loc434 at #loc634)) +#loc844 = loc(callsite(#loc435 at #loc634)) +#loc845 = loc(callsite(#loc436 at #loc634)) +#loc846 = loc(callsite(#loc437 at #loc634)) +#loc847 = loc(callsite(#loc439 at #loc634)) +#loc848 = loc(callsite(#loc440 at #loc634)) +#loc849 = loc(callsite(#loc441 at #loc634)) +#loc850 = loc(callsite(#loc442 at #loc635)) +#loc851 = loc(callsite(#loc444 at #loc635)) +#loc852 = loc(callsite(#loc445 at #loc635)) +#loc853 = loc(callsite(#loc446 at #loc635)) +#loc854 = loc(callsite(#loc447 at #loc635)) +#loc855 = loc(callsite(#loc448 at #loc635)) +#loc856 = loc(callsite(#loc449 at #loc635)) +#loc857 = loc(callsite(#loc450 at #loc635)) +#loc858 = loc(callsite(#loc451 at #loc635)) +#loc859 = loc(callsite(#loc452 at #loc635)) +#loc860 = loc(callsite(#loc453 at #loc635)) +#loc861 = loc(callsite(#loc454 at #loc635)) +#loc862 = loc(callsite(#loc455 at #loc635)) +#loc863 = loc(callsite(#loc456 at #loc635)) +#loc864 = loc(callsite(#loc457 at #loc635)) +#loc865 = loc(callsite(#loc458 at #loc635)) +#loc866 = loc(callsite(#loc459 at #loc635)) +#loc867 = loc(callsite(#loc13 at #loc662)) +#loc868 = loc(callsite(#loc15 at #loc662)) +#loc869 = loc("offs_m1"(#loc665)) +#loc870 = loc(callsite(#loc91 at #loc666)) +#loc871 = loc(callsite(#loc94 at #loc666)) +#loc872 = loc(callsite(#loc95 at #loc666)) +#loc873 = loc(callsite(#loc98 at #loc674)) +#loc874 = loc(callsite(#loc98 at #loc675)) +#loc875 = loc(callsite(#loc262 at #loc706)) +#loc876 = loc(callsite(#loc60 at #loc706)) +#loc877 = loc(callsite(#loc61 at #loc706)) +#loc878 = loc(callsite(#loc442 at #loc721)) +#loc879 = loc(callsite(#loc444 at #loc721)) +#loc880 = loc(callsite(#loc445 at #loc721)) +#loc881 = loc(callsite(#loc446 at #loc721)) +#loc882 = loc(callsite(#loc447 at #loc721)) +#loc883 = loc(callsite(#loc448 at #loc721)) +#loc884 = loc(callsite(#loc449 at #loc721)) +#loc885 = loc(callsite(#loc450 at #loc721)) +#loc886 = loc(callsite(#loc451 at #loc721)) +#loc887 = loc(callsite(#loc452 at #loc721)) +#loc888 = loc(callsite(#loc453 at #loc721)) +#loc889 = loc(callsite(#loc454 at #loc721)) +#loc890 = loc(callsite(#loc455 at #loc721)) +#loc891 = loc(callsite(#loc456 at #loc721)) +#loc892 = loc(callsite(#loc457 at #loc721)) +#loc893 = loc(callsite(#loc458 at #loc721)) +#loc894 = loc(callsite(#loc459 at #loc721)) +#loc895 = loc(callsite(#loc91 at #loc737)) +#loc896 = loc(callsite(#loc94 at #loc737)) +#loc897 = loc(callsite(#loc95 at #loc737)) +#loc898 = loc(callsite(#loc262 at #loc750)) +#loc899 = loc(callsite(#loc60 at #loc750)) +#loc900 = loc(callsite(#loc61 at #loc750)) +#loc901 = loc(callsite(#loc442 at #loc764)) +#loc902 = loc(callsite(#loc444 at #loc764)) +#loc903 = loc(callsite(#loc445 at #loc764)) +#loc904 = loc(callsite(#loc446 at #loc764)) +#loc905 = loc(callsite(#loc447 at #loc764)) +#loc906 = loc(callsite(#loc448 at #loc764)) +#loc907 = loc(callsite(#loc449 at #loc764)) +#loc908 = loc(callsite(#loc450 at #loc764)) +#loc909 = loc(callsite(#loc451 at #loc764)) +#loc910 = loc(callsite(#loc452 at #loc764)) +#loc911 = loc(callsite(#loc453 at #loc764)) +#loc912 = loc(callsite(#loc454 at #loc764)) +#loc913 = loc(callsite(#loc455 at #loc764)) +#loc914 = loc(callsite(#loc456 at #loc764)) +#loc915 = loc(callsite(#loc457 at #loc764)) +#loc916 = loc(callsite(#loc458 at #loc764)) +#loc917 = loc(callsite(#loc459 at #loc764)) +#loc918 = loc("vT_ptrs"(#loc772)) +#loc919 = loc(callsite(#loc91 at #loc773)) +#loc920 = loc(callsite(#loc94 at #loc773)) +#loc921 = loc(callsite(#loc95 at #loc773)) +#loc922 = loc(callsite(#loc98 at #loc776)) +#loc923 = loc(callsite(#loc98 at #loc777)) +#loc924 = loc(callsite(#loc95 at #loc807)) +#loc925 = loc(callsite(#loc91 at #loc834)) +#loc926 = loc(callsite(#loc94 at #loc834)) +#loc927 = loc(callsite(#loc95 at #loc834)) +#loc928 = loc(callsite(#loc95 at #loc841)) +#loc929 = loc("qT_ptrs"(#loc869)) +#loc930 = loc(callsite(#loc918 at #loc385)) +#loc931 = loc(callsite(#loc918 at #loc470)) +#loc932 = loc("do_ptrs"(#loc929)) +#loc933 = loc(callsite(#loc932 at #loc207)) +#loc934 = loc(callsite(#loc932 at #loc291)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.cubin b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.cubin new file mode 100644 index 0000000000000000000000000000000000000000..4ad9fe7b478b2653e078ae3b174be20687daa766 Binary files /dev/null and b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.cubin differ diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ptx b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ptx new file mode 100644 index 0000000000000000000000000000000000000000..c6aace71f6efc84ebb34d6f87480380dcd231abb --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ptx @@ -0,0 +1,224 @@ +// +// Generated by LLVM NVPTX Back-End +// + +.version 8.7 +.target sm_90a +.address_size 64 + + // .globl triton_poi_fused_mul_1 // -- Begin function triton_poi_fused_mul_1 + // @triton_poi_fused_mul_1 +.visible .entry triton_poi_fused_mul_1( + .param .u64 .ptr .global .align 1 triton_poi_fused_mul_1_param_0, + .param .u64 .ptr .global .align 1 triton_poi_fused_mul_1_param_1, + .param .u32 triton_poi_fused_mul_1_param_2, + .param .u64 .ptr .global .align 1 triton_poi_fused_mul_1_param_3, + .param .u64 .ptr .global .align 1 triton_poi_fused_mul_1_param_4 +) +.reqntid 128 +{ + .reg .pred %p<3>; + .reg .b32 %r<11>; + .reg .b64 %rd<6>; + .loc 1 18 0 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:18:0 +$L__func_begin0: + .loc 1 18 0 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:18:0 + +// %bb.0: + ld.param.b64 %rd3, [triton_poi_fused_mul_1_param_0]; + ld.param.b64 %rd4, [triton_poi_fused_mul_1_param_1]; +$L__tmp0: + .loc 1 20 28 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:20:28 + mov.u32 %r5, %ctaid.x; + .loc 1 20 33 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:20:33 + shl.b32 %r6, %r5, 8; + .loc 1 21 36 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:21:36 + mov.u32 %r7, %tid.x; + shl.b32 %r8, %r7, 1; + and.b32 %r9, %r8, 254; + .loc 1 21 23 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:21:23 + or.b32 %r10, %r9, %r6; + .loc 1 22 21 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:22:21 + setp.lt.s32 %p1, %r10, 31232; + .loc 1 24 30 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:24:30 + mul.wide.s32 %rd5, %r10, 4; + add.s64 %rd1, %rd3, %rd5; + .loc 1 24 35 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:24:35 + // begin inline asm + mov.u32 %r1, 0x0; + mov.u32 %r2, 0x0; + @%p1 ld.global.v2.b32 { %r1, %r2 }, [ %rd1 + 0 ]; + // end inline asm + .loc 1 26 18 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:26:18 + mul.f32 %r3, %r1, 0f3F317218; + mul.f32 %r4, %r2, 0f3F317218; + .loc 1 27 25 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:27:25 + add.s64 %rd2, %rd4, %rd5; + .loc 1 27 36 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:27:36 + // begin inline asm + @%p1 st.global.v2.b32 [ %rd2 + 0 ], { %r3, %r4 }; + // end inline asm + .loc 1 27 4 // ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py:27:4 + ret; +$L__tmp1: +$L__func_end0: + // -- End function +} + .file 1 "/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py" + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 0 // DW_CHILDREN_no +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } + .section .debug_info + { +.b32 139 // Length of Unit +.b8 2 // DWARF version number +.b8 0 +.b32 .debug_abbrev // Offset Into Abbrev. Section +.b8 8 // Address Size (in bytes) +.b8 1 // Abbrev [1] 0xb:0x84 DW_TAG_compile_unit +.b8 116 // DW_AT_producer +.b8 114 +.b8 105 +.b8 116 +.b8 111 +.b8 110 +.b8 0 +.b8 2 // DW_AT_language +.b8 0 +.b8 99 // DW_AT_name +.b8 116 +.b8 102 +.b8 109 +.b8 103 +.b8 114 +.b8 53 +.b8 120 +.b8 105 +.b8 101 +.b8 115 +.b8 112 +.b8 118 +.b8 122 +.b8 105 +.b8 106 +.b8 114 +.b8 109 +.b8 104 +.b8 103 +.b8 98 +.b8 97 +.b8 108 +.b8 55 +.b8 53 +.b8 114 +.b8 50 +.b8 117 +.b8 112 +.b8 112 +.b8 54 +.b8 104 +.b8 99 +.b8 97 +.b8 108 +.b8 98 +.b8 104 +.b8 98 +.b8 108 +.b8 110 +.b8 117 +.b8 103 +.b8 112 +.b8 101 +.b8 106 +.b8 103 +.b8 105 +.b8 112 +.b8 120 +.b8 108 +.b8 114 +.b8 120 +.b8 46 +.b8 112 +.b8 121 +.b8 0 +.b32 .debug_line // DW_AT_stmt_list +.b8 47 // DW_AT_comp_dir +.b8 119 +.b8 111 +.b8 114 +.b8 107 +.b8 115 +.b8 112 +.b8 97 +.b8 99 +.b8 101 +.b8 47 +.b8 104 +.b8 97 +.b8 110 +.b8 114 +.b8 117 +.b8 105 +.b8 47 +.b8 106 +.b8 117 +.b8 110 +.b8 113 +.b8 117 +.b8 97 +.b8 110 +.b8 47 +.b8 83 +.b8 112 +.b8 101 +.b8 99 +.b8 70 +.b8 111 +.b8 114 +.b8 103 +.b8 101 +.b8 47 +.b8 99 +.b8 97 +.b8 99 +.b8 104 +.b8 101 +.b8 47 +.b8 99 +.b8 111 +.b8 109 +.b8 112 +.b8 105 +.b8 108 +.b8 101 +.b8 100 +.b8 95 +.b8 107 +.b8 101 +.b8 114 +.b8 110 +.b8 101 +.b8 108 +.b8 115 +.b8 47 +.b8 116 +.b8 102 +.b8 0 + } + .section .debug_macinfo { } diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.source b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.source new file mode 100644 index 0000000000000000000000000000000000000000..253b3e8275a8b332da2684adee17fc287226242b --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.source @@ -0,0 +1,51 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":18:0) +#loc14 = loc("in_ptr0"(#loc)) +#loc15 = loc("out_ptr0"(#loc)) +#loc16 = loc("xnumel"(#loc)) +module { + tt.func public @triton_poi_fused_mul_1(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc))) attributes {noinline = false} { + %xnumel_0 = arith.constant 31232 : i32 loc(#loc17) + %xoffset = tt.get_program_id x : i32 loc(#loc18) + %xoffset_1 = arith.constant 256 : i32 loc(#loc19) + %xoffset_2 = arith.constant 256 : i32 loc(#loc19) + %xoffset_3 = arith.muli %xoffset, %xoffset_2 : i32 loc(#loc19) + %xindex = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc20) + %xindex_4 = tt.splat %xoffset_3 : i32 -> tensor<256xi32> loc(#loc21) + %xindex_5 = arith.addi %xindex_4, %xindex : tensor<256xi32> loc(#loc21) + %xmask = arith.constant dense<31232> : tensor<256xi32> loc(#loc22) + %xmask_6 = arith.cmpi slt, %xindex_5, %xmask : tensor<256xi32> loc(#loc22) + %tmp0 = tt.splat %in_ptr0 : !tt.ptr -> tensor<256x!tt.ptr> loc(#loc23) + %tmp0_7 = tt.addptr %tmp0, %xindex_5 : tensor<256x!tt.ptr>, tensor<256xi32> loc(#loc23) + %tmp0_8 = tt.load %tmp0_7, %xmask_6 : tensor<256x!tt.ptr> loc(#loc24) + %tmp1 = arith.constant 0.693147182 : f32 loc(#loc25) + %tmp2 = arith.constant dense<0.693147182> : tensor<256xf32> loc(#loc26) + %tmp2_9 = arith.mulf %tmp0_8, %tmp2 : tensor<256xf32> loc(#loc26) + %0 = tt.splat %out_ptr0 : !tt.ptr -> tensor<256x!tt.ptr> loc(#loc11) + %1 = tt.addptr %0, %xindex_5 : tensor<256x!tt.ptr>, tensor<256xi32> loc(#loc11) + tt.store %1, %tmp2_9, %xmask_6 : tensor<256x!tt.ptr> loc(#loc12) + tt.return loc(#loc13) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":19:13) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:28) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:33) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:36) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:23) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":22:21) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:30) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:35) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":25:11) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":26:18) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:25) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:36) +#loc13 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:4) +#loc17 = loc("xnumel"(#loc1)) +#loc18 = loc("xoffset"(#loc2)) +#loc19 = loc("xoffset"(#loc3)) +#loc20 = loc("xindex"(#loc4)) +#loc21 = loc("xindex"(#loc5)) +#loc22 = loc("xmask"(#loc6)) +#loc23 = loc("tmp0"(#loc7)) +#loc24 = loc("tmp0"(#loc8)) +#loc25 = loc("tmp1"(#loc9)) +#loc26 = loc("tmp2"(#loc10)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttgir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttgir new file mode 100644 index 0000000000000000000000000000000000000000..d21383e7c322662a9227a989c57cdbe488296d9e --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttgir @@ -0,0 +1,46 @@ +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":18:0) +#loc13 = loc("in_ptr0"(#loc)) +#loc14 = loc("out_ptr0"(#loc)) +#loc15 = loc("xnumel"(#loc)) +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @triton_poi_fused_mul_1(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc))) attributes {noinline = false} { + %cst = arith.constant dense<31232> : tensor<256xi32, #blocked> loc(#loc1) + %c256_i32 = arith.constant 256 : i32 loc(#loc1) + %cst_0 = arith.constant dense<0.693147182> : tensor<256xf32, #blocked> loc(#loc1) + %xoffset = tt.get_program_id x : i32 loc(#loc16) + %xoffset_1 = arith.muli %xoffset, %c256_i32 : i32 loc(#loc17) + %xindex = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked> loc(#loc18) + %xindex_2 = tt.splat %xoffset_1 : i32 -> tensor<256xi32, #blocked> loc(#loc19) + %xindex_3 = arith.addi %xindex_2, %xindex : tensor<256xi32, #blocked> loc(#loc19) + %xmask = arith.cmpi slt, %xindex_3, %cst : tensor<256xi32, #blocked> loc(#loc20) + %tmp0 = tt.splat %in_ptr0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked> loc(#loc21) + %tmp0_4 = tt.addptr %tmp0, %xindex_3 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> loc(#loc21) + %tmp0_5 = tt.load %tmp0_4, %xmask : tensor<256x!tt.ptr, #blocked> loc(#loc22) + %tmp2 = arith.mulf %tmp0_5, %cst_0 : tensor<256xf32, #blocked> loc(#loc23) + %0 = tt.splat %out_ptr0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked> loc(#loc10) + %1 = tt.addptr %0, %xindex_3 : tensor<256x!tt.ptr, #blocked>, tensor<256xi32, #blocked> loc(#loc10) + tt.store %1, %tmp2, %xmask : tensor<256x!tt.ptr, #blocked> loc(#loc11) + tt.return loc(#loc12) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:28) +#loc3 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:33) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:36) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:23) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":22:21) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:30) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:35) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":26:18) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:25) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:36) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:4) +#loc16 = loc("xoffset"(#loc2)) +#loc17 = loc("xoffset"(#loc3)) +#loc18 = loc("xindex"(#loc4)) +#loc19 = loc("xindex"(#loc5)) +#loc20 = loc("xmask"(#loc6)) +#loc21 = loc("tmp0"(#loc7)) +#loc22 = loc("tmp0"(#loc8)) +#loc23 = loc("tmp2"(#loc9)) diff --git a/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttir b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttir new file mode 100644 index 0000000000000000000000000000000000000000..0220cce0a6843a8257886db196edb50af3ea16c8 --- /dev/null +++ b/progress/github/SpecForge/cache/compiled_kernels/triton/7/R5RELHCRDUS6C7SJCJ42TZV5UZMVMY6IASMWGQCRUN553QFBJJEA/triton_poi_fused_mul_1.ttir @@ -0,0 +1,45 @@ +#loc = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":18:0) +#loc13 = loc("in_ptr0"(#loc)) +#loc14 = loc("out_ptr0"(#loc)) +#loc15 = loc("xnumel"(#loc)) +module { + tt.func public @triton_poi_fused_mul_1(%in_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("in_ptr0"(#loc)), %out_ptr0: !tt.ptr {tt.divisibility = 16 : i32} loc("out_ptr0"(#loc)), %xnumel: i32 {tt.divisibility = 16 : i32} loc("xnumel"(#loc))) attributes {noinline = false} { + %tmp2 = arith.constant dense<0.693147182> : tensor<256xf32> loc(#loc16) + %xmask = arith.constant dense<31232> : tensor<256xi32> loc(#loc17) + %c256_i32 = arith.constant 256 : i32 loc(#loc3) + %xoffset = tt.get_program_id x : i32 loc(#loc18) + %xoffset_0 = arith.muli %xoffset, %c256_i32 : i32 loc(#loc19) + %xindex = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc20) + %xindex_1 = tt.splat %xoffset_0 : i32 -> tensor<256xi32> loc(#loc21) + %xindex_2 = arith.addi %xindex_1, %xindex : tensor<256xi32> loc(#loc21) + %xmask_3 = arith.cmpi slt, %xindex_2, %xmask : tensor<256xi32> loc(#loc17) + %tmp0 = tt.splat %in_ptr0 : !tt.ptr -> tensor<256x!tt.ptr> loc(#loc22) + %tmp0_4 = tt.addptr %tmp0, %xindex_2 : tensor<256x!tt.ptr>, tensor<256xi32> loc(#loc22) + %tmp0_5 = tt.load %tmp0_4, %xmask_3 : tensor<256x!tt.ptr> loc(#loc23) + %tmp2_6 = arith.mulf %tmp0_5, %tmp2 : tensor<256xf32> loc(#loc16) + %0 = tt.splat %out_ptr0 : !tt.ptr -> tensor<256x!tt.ptr> loc(#loc10) + %1 = tt.addptr %0, %xindex_2 : tensor<256x!tt.ptr>, tensor<256xi32> loc(#loc10) + tt.store %1, %tmp2_6, %xmask_3 : tensor<256x!tt.ptr> loc(#loc11) + tt.return loc(#loc12) + } loc(#loc) +} loc(#loc) +#loc1 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":26:18) +#loc2 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":22:21) +#loc3 = loc(unknown) +#loc4 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:28) +#loc5 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":20:33) +#loc6 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:36) +#loc7 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":21:23) +#loc8 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:30) +#loc9 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":24:35) +#loc10 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:25) +#loc11 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:36) +#loc12 = loc("/workspace/hanrui/junquan/SpecForge/cache/compiled_kernels/tf/ctfmgr5xiespvzijrmhgbal75r2upp6hcalbhblnugpejgipxlrx.py":27:4) +#loc16 = loc("tmp2"(#loc1)) +#loc17 = loc("xmask"(#loc2)) +#loc18 = loc("xoffset"(#loc4)) +#loc19 = loc("xoffset"(#loc5)) +#loc20 = loc("xindex"(#loc6)) +#loc21 = loc("xindex"(#loc7)) +#loc22 = loc("tmp0"(#loc8)) +#loc23 = loc("tmp0"(#loc9)) diff --git a/progress/github/SpecForge/docs/_static/css/custom_log.css b/progress/github/SpecForge/docs/_static/css/custom_log.css new file mode 100644 index 0000000000000000000000000000000000000000..61f65d0199df9e97886560f7f97c6c9b026bd34e --- /dev/null +++ b/progress/github/SpecForge/docs/_static/css/custom_log.css @@ -0,0 +1,29 @@ +.output_area { + color: #615656; +} + +table.autosummary td { + width: 50% + } + + img.align-center { + display: block; + margin-left: auto; + margin-right: auto; +} + +.output_area.stderr { + color: #d3d3d3 !important; +} + +.output_area.stdout { + color: #d3d3d3 !important; +} + +div.output_area.stderr { + color: #d3d3d3 !important; +} + +div.output_area.stdout { + color: #d3d3d3 !important; +} diff --git a/progress/github/SpecForge/docs/_static/css/readthedocs.css b/progress/github/SpecForge/docs/_static/css/readthedocs.css new file mode 100644 index 0000000000000000000000000000000000000000..aca6649b436a35cf39b2c924ce2f74ed2cdc8b90 --- /dev/null +++ b/progress/github/SpecForge/docs/_static/css/readthedocs.css @@ -0,0 +1,9 @@ +table.autosummary td { + width: 50% +} + +img.align-center { + display: block; + margin-left: auto; + margin-right: auto; +} diff --git a/progress/github/SpecForge/docs/advanced_features/customization.md b/progress/github/SpecForge/docs/advanced_features/customization.md new file mode 100644 index 0000000000000000000000000000000000000000..47b624a9ce461b5f37aa6c159bcb306657c68ed4 --- /dev/null +++ b/progress/github/SpecForge/docs/advanced_features/customization.md @@ -0,0 +1,118 @@ +# 💡 Customize Your Own Training + +## 🔧 Customize Training Args + +```bash +torchrun \ + --standalone \ + --nproc_per_node 8 \ + ./scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config ./configs/llama3-8B-eagle3.json \ + --train-data-path ./cache/dataset/sharegpt.jsonl \ + --output-dir ./outputs/llama3-8b-eagle3 \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template llama3 \ + --cache-dir ./cache +``` + +If you wish to understand what each argument does, you can run `python scripts/train_eagle3.py --help` to see the full list of arguments. Particularly, we will discuss some important arguments below. +- `--chat-template`: This should be the chat template to use for the model, so please make sure you set it to the correct value. +- `--cache-dir`: This directory contains the dataset cache including the `input_ids`, `loss_mask`, `attention_mask` and `vocab_mapping`. These caches can make your data loading much faster once a cache is generated. The cache file has a name which is obtained by hashing the dataset path to avoid cache collision. + +## 💬 Customize Chat Template + +You can register a new chat template for your model by adding a new entry to the `TEMPLATE_REGISTRY` in the `specforge.data.template.py` file. + +```python +TEMPLATE_REGISTRY.register( + name="your-template-name", + template=ChatTemplate( + assistant_header="xxx", + user_header="xxx", + system_prompt="xxx", + end_of_turn_token="xxx", + ), +) +``` + +## 🪅 Customize Model + +### Customize Target Model + +If you wish to train Eagle3 for other models, you need to modify the `--target-model-path` value. We support loading these models directly from HuggingFace. + +However, if your model is too large and requires tensor parallelism, you can implement its tensor parallel version on your own in the `specforge.modeling.target` directory. The CausalLM model should inherit the `DistributedTargetModel` class in the `specforge.modeling.target.base.py` file and apply `ColumnParallelLinear` and `RowParallelLinear` to its submodules. + +```python +from .base import DistributedTargetModel +from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear + + +class MyModelForCausalLM(MyModelPreTrainedModel, GenerationMixin, DistributedTargetModel): + ... + + def load_weights(self, state_dict: Dict[str, torch.Tensor]): + ... +``` + +Afterwards, you need to register this model to the `AutoEagle3TargetModel` class in the `specforge.modeling.auto.py` file. + +```diff +class AutoDistributedTargetModel(AutoModelForCausalLMBase): + _model_mapping = { + Llama4TextConfig: [Llama4ForCausalLM], ++ MyModelConfig: [MyModelForCausalLM], + } +``` + +When `tp_size` is greater than 1, the script will automatically load the distributed version of the model for tensor parallelism. + +### Customize Draft Model + +If you want to change the draft model configuration, you can write your own configuration file and pass its path to the `--draft-model-config` argument. Or, if you do not provide the `--draft-model-config` argument, the script will automatically generate the draft model configuration based on the target model configuration. If you wish to serve your customized draft model with SGLang, make sure you implement the draft model in SGLang as well and the architecture name must match. To implement your own draft model, you can create a new class and inherit it from the `Eagle3DraftModel` class in the `specforge.modeling.draft.base.py` file. + + +```python +from .base import Eagle3DraftModel +from transformers import PretrainedConfig + + +class MyModelConfig(PretrainedConfig): + model_type = "mymodel" + + def __init__(self, **kwargs): + ... + + +class MyModelEagle3(Eagle3DraftModel): + + config_class = MyModelConfig + + def __init__(self, config, quant_config=None) -> None: + ... +``` + +You can then register these models to the `AutoEagle3TargetModel` and `AutoDraftModelConfig` classes in the `specforge.modeling.auto.py` file for the automatic model loading. + +```diff +class AutoEagle3DraftModel(AutoModelForCausalLMBase): + # the model mapping is currently hardcoded, we should support lazy model mapping via registry + _model_mapping = { + LlamaConfig: [LlamaForCausalLMEagle3], ++ MyModelConfig: MyModelEagle3, + } + + +class AutoDraftModelConfig: + + _config_mapping = { + "LlamaForCausalLMEagle3": LlamaConfig, ++ "MyModelEagle3": MyModelConfig, + } +``` + +In this way, as long as your `config.json` specifies the correct architecture name, the script will automatically load the correct draft model for you. diff --git a/progress/github/SpecForge/docs/basic_usage/training.md b/progress/github/SpecForge/docs/basic_usage/training.md new file mode 100644 index 0000000000000000000000000000000000000000..a41b5a0dee1a9a12620f25ae26f613f4711d0b7c --- /dev/null +++ b/progress/github/SpecForge/docs/basic_usage/training.md @@ -0,0 +1,62 @@ +## 🚀 Training + +## 📍 Overview + +Existing speculative decoding methods such as EAGLE3 requires training in the feature-space, which means the draft model relies on the hidden states generated from the target model for autoregressive prediction. In SpecForge, we provide two orthogonal paths to cater to the users' specific needs when training this kind of draft models. We name these two methods as `Online` and `Offline`. By definition, it is easy to understandd them: + +- **`Online`**: the hidden states are generated on the fly during training. +- **`Offline`**: the hidden states are generated beforehand, stored to the disk, and loaded back to GPU during training. + +Online training is suitable for users with limited disk space but sufficient GPUs while offline training is suitable for users with sufficient disk space but limited GPUs. + +| Method | Target Model | Disk Space Requirement | GPU Requirement | One-liner rationale | +| --- | --- | --- | --- | --- | +| Online | Used during training | Small | More GPUs are needed if your target model is large | Generating auxiliary hidden states on the fly | +| Offline | Only used during data preparation | Huge (e.g. ultrachat+sharegpt will need 12TB storage ) | as low as 1 GPU, as only need to accommodate the draft model | Preparing auxiliary hidden states beforehand and only once | + +> **Why does disk matter?** +> During Eagle3 training, the frozen target model will first generate the hidden states for each token given the data sample. The hidden states are fed to the draft model for training. +> Offline mode stores these hidden states to the local disk, so a small disk can be filled up fast. +> Online mode only generates these hidden states on the fly without storing them to the disk, but needs to keep the target model resident in memory during training, trading GPU RAM for almost-zero disk footprint. + +## 🏎️ Online Training + +We have provided training scripts for the EAGLE3 models in the `examples` directory. These scripts cover a wide range of models range from Llama to Qwen, small to large and dense to MoE. Online training is often conducted in two steps and we will use ShareGPT and Llama3-8B-Instruct as an example. + +**Step 1: Prepare the dataset** + +```bash +# prepare the dataset +python scripts/prepare_data.py --dataset sharegpt +``` + +**Step 2: Start the training** + +```bash +# train llama3-8B-instruct +bash ./examples/run_llama3.1_8b_eagle3_online.sh +``` + +## 💨 Offline Training + +The difference between online and offline training is that we need to generate the hidden states before training. We also use ShareGPT and Llama3-8B-Instruct as an example. + +**Step 1: Prepare the dataset** + +Same as above + +**Step 2: Generate the hidden states and train** + +```bash +# train llama3-8B-instruct in an offline manner +bash ./examples/run_llama3.1_8b_eagle3_offline.sh +``` + +It is important to note that the `run_llama3.1_8b_eagle3_offline.sh` script consists of two steps: + +1. Generate the hidden states using the `prepare_hidden_states.py` script. This script will generate the hidden states for the test and train datasets and save them to the disk. +2. Train the model: suppling the `--train-hidden-states-path` argument to the script so that the script will load the hidden states from the disk during training. + +## 📈 Experiment Tracking + +This project supports logging training progress to Wandb, TensorBoard, and SwanLab. You can enable tracking by adding the `--report-to` argument to the command line in your shell script. diff --git a/progress/github/SpecForge/docs/benchmarks/benchmark.md b/progress/github/SpecForge/docs/benchmarks/benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..29a51b35d5d7639ebd666202aad3377063e4ee12 --- /dev/null +++ b/progress/github/SpecForge/docs/benchmarks/benchmark.md @@ -0,0 +1,67 @@ +# Benchmarking for Speculative Decoding + +## Overview + +We provide a unified script to test the performance of the Speculative Decoding with EAGLE3 algorithm on multiple datasets. You can follow the steps below to run the benchmarks. + +## Run Benchmarks + +### Launch SGLang and Benchmarker Concurrently + +`bench_eagle3.py` can help you launch a SGLang server process and a Benchmarking process concurrently. In this way, you don't have to launch the SGLang server manually, this script will manually handle the SGLang launch under different speculative decoding configurations. Some important arguments are: +- `--model-path`: the path to the target model. +- `--speculative-draft-model-path`: the path to the draft model. +- `--port`: the port to launch the SGLang server. +- `--trust-remote-code`: trust the remote code. +- `--mem-fraction-static`: the memory fraction for the static memory. +- `--tp-size`: the tensor parallelism size. +- `--attention-backend`: the attention backend. +- `--config-list`: the list of speculative decoding configuration to test, the format is `,,,`. +- `--benchmark-list`: the list of benchmarks to test, the format is `::`. + +```shell +python3 bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --port 30000 \ + --trust-remote-code \ + --mem-fraction-static 0.8 \ + --tp-size 1 \ + --attention-backend fa3 \ + --config-list 1,0,0,0 1,3,1,4 \ + --benchmark-list mtbench gsm8k:5 ceval:5:accountant \ + --dtype bfloat16 +``` + +### Launch Benchmarker Independently + +If you want to launch the SGLang server independently, you can use the following command. + +```shell +# you can launch a server +python3 -m sglang.launch_server \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 \ + --mem-fraction-static 0.75 \ + --cuda-graph-max-bs 1 \ + --tp 1 \ + --trust-remote-code \ + --host 0.0.0.0 \ + --port 30000 \ + --dtype bfloat16 +``` + +Then we can start benchmarking. Note that you should use the same host and port as the one used in the SGLang server. Note that `--skip-launch-server` is required to skip the launch of the SGLang server. + +```bash +python bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --port 30000 \ + --config-list 1,3,1,4 \ + --benchmark-list mtbench:5 ceval:5:accountant gsm8k:5 humaneval:5 math500:5 mtbench:5 aime:1 \ + --skip-launch-server +``` diff --git a/progress/github/SpecForge/docs/community_resources/specbundle.md b/progress/github/SpecForge/docs/community_resources/specbundle.md new file mode 100644 index 0000000000000000000000000000000000000000..5efb84e84e72a98be42ce445eff6c0a5e7d6bcda --- /dev/null +++ b/progress/github/SpecForge/docs/community_resources/specbundle.md @@ -0,0 +1,93 @@ +# 🔥 SpecBundle + +
+ specbundle logo +
+ + +## About SpecBundle + +Speculative decoding, especially EAGLE3, offer strong theoretical guarantees alongside consistent empirical improvements in token acceptance rate and end-to-end inference speed. However, despite these advances, adoption of speculative decoding—especially EAGLE3—remains limited in the open-source ecosystem, due primarily to three key factors. + +1. Lack of production-ready training infrastructure: Existing speculative decoding toolchains are largely research prototypes, offering limited system-level optimization and inadequate support for diverse architectures and large-scale models. +2. Scarcity of high-quality draft models: Effective speculative decoding depends on strong draft models, yet publicly available EAGLE3-compatible checkpoints are extremely limited, primarily originating from the original authors. +3. Insufficient training scale of existing drafts: Most available draft models are trained on small or curated datasets and fail to generalize to the large, diverse corpora used in modern LLM training, resulting in low token acceptance rates and diminished practical speedups. + +**SpecBundle** is a direct response to these limitations. Jointly driven by the open-source community and industry partners including **Ant Group**, **Meituan**, **Nex-AGI** and **EigenAI**, **SpecBundle** represents the **first open initiative** aimed at democratizing speculative decoding by providing high-performance, production-grade EAGLE3 draft model weights for mainstream open-source LLMs. This initiative also serves to verify the robustness of the **SpecForge** framework through multiple scales and architectures. + +We call for all open-source developers and industry partners to join this exciting initiative. + +## Performance Scores + +We evaluate the performance of SpecBundle draft models on various benchmarks, please visit the [Performance Dashboard](https://docs.sglang.io/SpecForge/SpecBundle/index.html) for more details. + +## Usage + +You can use the following command to launch the SGLang server with SpecBundle models. Please add `--tp`, `--ep` and `--mem-fraction-static` arguments when you encounter memory issues. + +```bash +python3 -m sglang.launch_server \ + --model \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path \ + --speculative-num-steps 3 \ + --speculative-eagle-topk 1 \ + --speculative-num-draft-tokens 4 +``` + +## Released Models + +We list the models released by the SpecForge and several industrial partners below. These models are released as part of the SpecBundle models, which are trained on large-scale multi-domain datasets and deliver exceptional performance on various benchmarks. + +> We also include some of the models previously trained by the SpecForge team but not technically part of the SpecBundle release. +> We mark models trained on ShareGPT+Ultrachat datasets with a **\*** mark and models trained on Perfect-Blend datasets but released before SpecBundle with **+** mark. + +### Llama Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| meta-llama/Llama-3.1-8B-Instruct | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge) | [🤗 Dataset](https://huggingface.co/datasets/frankleeeee/PerfectBlend-Regenerated-Llama-3.1-8B-Instruct) | +| meta-llama/Llama-3.3-70B-Instruct | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge) | [🤗 Dataset](https://huggingface.co/datasets/frankleeeee/PerfectBlend-Regenerated-Llama-3.3-70B-Instruct) | +| meta-llama/Llama-4-Scout-17B-16E-Instruct | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge) | [🤗 Dataset](https://huggingface.co/datasets/frankleeeee/PerfectBlend-Regenerated-Llama-4-Scout-17B-16E-Instruct) | +| meta-llama/Llama-4-Maverick-17B-128E-Instruct | [🤗 Model *](https://huggingface.co/lmsys/sglang-EAGLE3-Llama-4-Maverick-17B-128E-Instruct-v1) | [🤗 Dataset](https://huggingface.co/datasets/frankleeeee/PerfectBlend-Regenerated-Llama-4-Maverick-17B-128E-Instruct) | + +### Qwen Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| Qwen/Qwen3-30B-A3B-Instruct-2507 | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge-Nex) | [🤗 Dataset](https://huggingface.co/datasets/lukeysong/qwen-30b-regen-blend) | +| Qwen/Qwen3-235B-A22B-Instruct-2507 | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge-Meituan) | [🤗 Dataset](https://huggingface.co/datasets/lukeysong/qwen3-235-regen-perfect_blend) | +| Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-perfect-blend-regenerated) | [🤗 Dataset](https://huggingface.co/datasets/lukeysong/qwen3-80b-regen-prefectblend) | + +### Qwen Coder Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| Qwen/Qwen3-Coder-30B-A3B-Instruct | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge) | [🤗 Dataset](https://huggingface.co/datasets/JinnP/opc_regen_Qwen3-Coder-30B-A3B-Instruct) | +| Qwen/Qwen3-Coder-480B-A35B-Instruct | [🤗 Model](https://huggingface.co/lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI) | - | + +### Ling Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| inclusionAI/Ling-flash-2.0 | [🤗 Model](https://huggingface.co/AQ-MedAI/Ling-Flash-2.0-eagle3) | - | + +### Kimi Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| moonshotai/Kimi-K2-Instruct | [🤗 Model](https://huggingface.co/AQ-MedAI/Kimi-K2-Instruct-eagle3) | - | + +### GPT-OSS Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| openai/gpt-oss-20b | [🤗 Model +](https://huggingface.co/zhuyksir/EAGLE3-gpt-oss-20b-bf16) | [🤗 Dataset](https://huggingface.co/datasets/zhuyksir/perfect-blend-gptoss-20B-1M) | +| openai/gpt-oss-120b | [🤗 Model +](https://huggingface.co/lmsys/EAGLE3-gpt-oss-120b-bf16) | - | + +### Nex Series + +| Target Model | EAGLE3 Draft Model | Regenerated Dataset | +|---------------|--------------------|--------------------| +| nex-agi/Qwen3-30B-A3B-Nex-N1 | [🤗 Model](https://huggingface.co/nex-agi/SGLANG-EAGLE3-Qwen3-30B-A3B-Nex-N1) | - | +| nex-agi/Qwen3-32B-Nex-N1 | [🤗 Model](https://huggingface.co/nex-agi/SGLANG-EAGLE3-Qwen3-32B-Nex-N1) | - | diff --git a/progress/github/SpecForge/docs/examples/llama3-eagle3-offline.md b/progress/github/SpecForge/docs/examples/llama3-eagle3-offline.md new file mode 100644 index 0000000000000000000000000000000000000000..a8449cc612073c9ca76804891ca9339fad9c8ca1 --- /dev/null +++ b/progress/github/SpecForge/docs/examples/llama3-eagle3-offline.md @@ -0,0 +1,57 @@ +# Eagle3 for Llama3 - Offline + +## Introduction + +This document provides a step-by-step guide on how to train the EAGLE3 model for the Llama3.1-8B-Instruct model in an offline manner. In offline training, we generate the hidden states required by EAGLE3 draft model beforehand and store them to the disk. During training, we load them back to the GPU memory. As offline training requires a lot of disk space, we do not recommend running this on large datasets such as Perfect-Blend. + +## Training on ShareGPT dataset + +### **Step 1. Prepare ShareGPT dataset** + +First of all, we should download the dataset. + +```shell +python ./scripts/prepare_data.py --dataset sharegpt +``` + +### **Step 2. Prepare Hidden States** + +We need to prepare the hidden states for the training. + +```shell +torchrun \ + --standalone \ + --nproc_per_node 8 \ + scripts/prepare_hidden_states.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --enable-aux-hidden-states \ + --data-path ./cache/dataset/sharegpt_train.jsonl \ + --output-path ./cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --chat-template llama3 \ + --max-length 4096 \ + --tp-size 1 \ + --batch-size 32 +``` + +The hidden states will be saved to the disk in the `output-path` directory. + +### **Step 3. Start Training** + +```shell +torchrun \ + --standalone \ + --nproc_per_node 8 \ + ./scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config ./configs/llama3-8B-eagle3.json \ + --train-data-path ./cache/dataset/sharegpt_train.jsonl \ + --train-hidden-states-path ./cache/hidden_states/sharegpt_train_Llama-3.1-8B-Instruct \ + --output-dir ./outputs/llama3-8b-eagle3-sharegpt-offline \ + --num-epochs 10 \ + --batch-size 1 \ + --tp-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --cache-dir ./cache +``` diff --git a/progress/github/SpecForge/docs/examples/llama3-eagle3-online.md b/progress/github/SpecForge/docs/examples/llama3-eagle3-online.md new file mode 100644 index 0000000000000000000000000000000000000000..13dd2fdd1c9ed9b9f06505a52b5db272c5a3bd49 --- /dev/null +++ b/progress/github/SpecForge/docs/examples/llama3-eagle3-online.md @@ -0,0 +1,75 @@ +# Eagle3 for Llama3 - Online + +## Introduction + +This document provides a step-by-step guide on how to train the EAGLE3 model for the Llama3.1-8B-Instruct model in an online manner. In online training, we generate the hidden states required by EAGLE3 draft model on the fly during training. This example is using `ShareGPT` dataset for training, the performance is not optimal due to the size and limited coverage of the dataset. If you look for optimal performance, we recommend you to try more diverse datasets such as [`Perfect-Blend`](https://huggingface.co/datasets/facebook/perfect-blend). We have also included a section on training on `Perfect-Blend` dataset at the end of this document. + + +## Training on ShareGPT dataset + +### **Step 1. Prepare ShareGPT dataset** + +First of all, we should download the dataset. + +```shell +python ./scripts/prepare_data.py --dataset sharegpt +``` + +### **Step 2. Launch Online Training** + +```shell +torchrun \ + --standalone \ + --nproc_per_node 8 \ + scripts/train_eagle3.py \ + --target-model-path meta-llama/Llama-3.1-8B-Instruct \ + --draft-model-config configs/llama3-8B-eagle3.json \ + --train-data-path ./cache/dataset/sharegpt_train.jsonl \ + --output-dir ./outputs/llama3-8b-eagle3 \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template llama3 \ + --target-model-backend sglang \ +``` + +### **Step 3. Benchmark** + +For `Llama3.1-8B`, we add a system prompt to all training data, following the approach used in the official repository. Consequently, when benchmarking, we should also include this system prompt to obtain the full accept length. Please uncomment the corresponding line and add the system prompt. + +The four numbers in the config represent: `batch_size, num_steps, topk, num_verify_tokens`. You can adjust the values in the config list to experiment with different test cases. + +A pre-trained EAGLE model is available at [zhuyksir/EAGLE3-Llama-3.1-8B-Instruct](https://huggingface.co/zhuyksir/EAGLE3-Llama-3.1-8B-Instruct) for reference. + +```shell +cd benchmarks + +config_list=( + "4,3,1,4" + "4,7,10,60" +) +python3 bench_eagle3.py \ + --model-path meta-llama/Llama-3.1-8B-Instruct \ + --speculative-draft-model-path /YOUR/PATH/Llama-3.1-8B-Instruct/dev_outputs/epoch_0 \ + --port 30000 \ + --mem-fraction-static 0.8 \ + --tp-size 1 \ + --config-list "${config_list[@]}" \ + --benchmark-list mtbench gsm8k humaneval math500 +``` + + +## Training on Perfect-Blend dataset + +### **Step 1. Prepare Perfect-Blend dataset** + +First of all, we should download the dataset. + +```shell +python ./scripts/prepare_data.py --dataset perfectblend +``` + +### **Step 2. Launch Online Training** + +We just need to change the `--train-data-path` to the path of the Perfect-Blend dataset (e.g. `./cache/dataset/perfectblend_train.jsonl`), then we can launch training smoothly. diff --git a/progress/github/SpecForge/docs/spec_bundle/index.html b/progress/github/SpecForge/docs/spec_bundle/index.html new file mode 100644 index 0000000000000000000000000000000000000000..ad336a93a9fd136ac55768562e96a1d8f324d001 --- /dev/null +++ b/progress/github/SpecForge/docs/spec_bundle/index.html @@ -0,0 +1,21 @@ + + + + + + + + + + + SpecBundle + + + +
+ + + + diff --git a/progress/github/SpecForge/docs/spec_bundle/public/raw_data/data.json b/progress/github/SpecForge/docs/spec_bundle/public/raw_data/data.json new file mode 100644 index 0000000000000000000000000000000000000000..f923184be9f11a0be51daa926e7b94ff1821007b --- /dev/null +++ b/progress/github/SpecForge/docs/spec_bundle/public/raw_data/data.json @@ -0,0 +1,6422 @@ +{ + "Qwen3-30B-A3B-Instruct-2507": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1071.2940027174511, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1488.3645940190918, + "accept_length": 2.6400593352844486 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1071.2940027174511, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1499.6157892300257, + "accept_length": 3.0113471715954674 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1071.2940027174511, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1491.1759364152986, + "accept_length": 2.525104073618391 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1071.2940027174511, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1438.3989235515564, + "accept_length": 3.1488859094681736 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1071.2940027174511, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1478.3371126866896, + "accept_length": 2.515156901620291 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1468.9518188983302, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3022.302541558449, + "accept_length": 3.4018400160943374 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1468.9518188983302, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3458.7683757488517, + "accept_length": 4.5001277922609 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1468.9518188983302, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2710.0700446913434, + "accept_length": 3.83069810232181 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1468.9518188983302, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3636.1457092511932, + "accept_length": 5.29297884876688 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1468.9518188983302, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2650.9994915668844, + "accept_length": 3.981701201346221 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1341.3462205459145, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2048.689292397081, + "accept_length": 2.495847913511255 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1341.3462205459145, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2086.117426859236, + "accept_length": 2.831051301639537 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1341.3462205459145, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1698.4151046745978, + "accept_length": 2.5572219713355357 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1341.3462205459145, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1998.1600180425269, + "accept_length": 2.9819193324061195 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1341.3462205459145, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1742.9797705522778, + "accept_length": 2.7422317575874455 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1366.6183006362219, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2618.165602951494, + "accept_length": 3.349328692192939 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1366.6183006362219, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2912.1392571686956, + "accept_length": 4.384426363785289 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1366.6183006362219, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2367.016477367958, + "accept_length": 3.7901897758795298 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1366.6183006362219, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3069.9815866099266, + "accept_length": 5.124267515923567 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1366.6183006362219, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2363.3377665362655, + "accept_length": 4.030938739532834 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1492.6190597361915, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2911.405162351629, + "accept_length": 3.1783624121672447 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1492.6190597361915, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3265.2547245227543, + "accept_length": 4.018270197787462 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1492.6190597361915, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2455.0885550482017, + "accept_length": 3.295517305362425 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1492.6190597361915, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 3413.029275629196, + "accept_length": 4.576331556763159 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1492.6190597361915, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2355.0941391264764, + "accept_length": 3.3973067623684012 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1320.1266846132082, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1778.9653109324079, + "accept_length": 2.0810309937160505 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1320.1266846132082, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1778.6778684706662, + "accept_length": 2.2730321793789288 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1320.1266846132082, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1652.1607344416184, + "accept_length": 2.2703352879266276 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1320.1266846132082, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1682.9566856293293, + "accept_length": 2.3032779273841584 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1320.1266846132082, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1753.6698041448958, + "accept_length": 2.6092096546804138 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1410.428038868636, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2237.792328921565, + "accept_length": 2.5958448251993995 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1410.428038868636, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2341.298191039886, + "accept_length": 3.0077922694984913 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1410.428038868636, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 1961.1700111065113, + "accept_length": 2.6947097860315505 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1410.428038868636, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2310.2053834681674, + "accept_length": 3.216540452331778 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1410.428038868636, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-30B-A3B-Instruct-2507-SpecForge", + "output_throughput": 2008.7425535412629, + "accept_length": 2.91748293468006 + } + ] + } + ] + } + }, + "Qwen3-235B-A22B-Instruct-2507": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 469.12940470010284, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 633.4834448509783, + "accept_length": 2.356716526992789 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 718.620120234308, + "accept_length": 2.8762828246719394 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 469.12940470010284, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 619.3961515217887, + "accept_length": 2.5325967285309847 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 740.8090293617215, + "accept_length": 3.351527622767857 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 469.12940470010284, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 685.8224688133159, + "accept_length": 2.2254637464335056 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 718.5200251720828, + "accept_length": 2.5942242348162705 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 469.12940470010284, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 622.6877352310961, + "accept_length": 2.577754285484885 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 758.2839780669175, + "accept_length": 3.51144398279758 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 469.12940470010284, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 696.9862910262393, + "accept_length": 2.2957518385545184 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 692.54543613971, + "accept_length": 2.508131344520406 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 587.3767625807179, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 821.7716217768141, + "accept_length": 2.2131311175007076 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1165.3481778903413, + "accept_length": 3.2287879445239853 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 587.3767625807179, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 786.5291154131861, + "accept_length": 2.3811060693210626 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1263.6658286467714, + "accept_length": 4.021472447253628 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 587.3767625807179, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 729.1280796475185, + "accept_length": 2.1641727527768047 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1012.7228976076004, + "accept_length": 3.3166681444513406 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 587.3767625807179, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 801.9730196026575, + "accept_length": 2.4202165987905055 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1399.195876342606, + "accept_length": 4.477737029876627 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 587.3767625807179, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 728.5917394731794, + "accept_length": 2.180077789251727 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 966.5149174357106, + "accept_length": 3.0996346930308336 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 529.8952857212083, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 642.7287443329789, + "accept_length": 1.8722335837366109 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 814.539845630713, + "accept_length": 2.3454133346915906 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 529.8952857212083, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 617.9738942581079, + "accept_length": 1.9436368219822697 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 779.531140147999, + "accept_length": 2.571956737666924 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 529.8952857212083, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 579.7478777831109, + "accept_length": 1.879637550849381 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 684.112380410899, + "accept_length": 2.3538604252889965 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 529.8952857212083, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 607.3644823224199, + "accept_length": 1.9674055586107704 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 789.9679697718769, + "accept_length": 2.6698328935795956 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 529.8952857212083, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 596.0590450290033, + "accept_length": 1.987328547838102 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 670.0058040199536, + "accept_length": 2.329033512672587 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 553.0503522362385, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 866.1813723921825, + "accept_length": 2.533027363039563 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1068.373749600453, + "accept_length": 3.238804311590177 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 553.0503522362385, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 853.4917713020631, + "accept_length": 2.8369721532226433 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1176.5192650014792, + "accept_length": 4.083723300745958 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 553.0503522362385, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 772.1684975661775, + "accept_length": 2.5123042505592843 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1032.477913431608, + "accept_length": 3.6360244115082825 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 553.0503522362385, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 889.8951303902317, + "accept_length": 2.955997016746898 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1267.5178598410528, + "accept_length": 4.4874762125186445 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 553.0503522362385, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 736.1010265214783, + "accept_length": 2.3861131594156686 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 983.9906558013464, + "accept_length": 3.412326127536581 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 598.1832041732818, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 803.7805606947842, + "accept_length": 2.090690935434212 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1062.9796952555507, + "accept_length": 2.8172381425652917 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 598.1832041732818, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 759.6333115912107, + "accept_length": 2.2179516111790765 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1093.1979234549972, + "accept_length": 3.268498808394456 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 598.1832041732818, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 708.4447966909656, + "accept_length": 2.077364507787014 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 874.062642276262, + "accept_length": 2.6670587896561795 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 598.1832041732818, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 767.8685797664081, + "accept_length": 2.2474642743536366 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 1155.6572987907093, + "accept_length": 3.490068495285106 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 598.1832041732818, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 711.4663371023372, + "accept_length": 2.129619842542645 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 835.6105646149398, + "accept_length": 2.590646146520392 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 539.5161023038148, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 689.4282413740445, + "accept_length": 1.941237358311274 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 872.4508905377182, + "accept_length": 2.556773924332344 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 539.5161023038148, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 636.4408069963314, + "accept_length": 2.027268079304664 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 885.529748337286, + "accept_length": 2.8442245393804413 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 539.5161023038148, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 642.4958901994291, + "accept_length": 2.0553746448296777 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 730.7331843587357, + "accept_length": 2.4330876223070512 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 539.5161023038148, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 641.1037073226237, + "accept_length": 2.0361251069493296 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 889.0304393086461, + "accept_length": 2.965008914078923 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 539.5161023038148, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 654.3422430101997, + "accept_length": 2.1356956699218137 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 742.3749721046132, + "accept_length": 2.5176210584474528 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 563.1619467852893, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 716.6967887897075, + "accept_length": 2.0240035915598344 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 823.4218898853592, + "accept_length": 2.356617214868455 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 563.1619467852893, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 680.2044274358036, + "accept_length": 2.14011469258975 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 808.934577824737, + "accept_length": 2.6032639643837037 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 563.1619467852893, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 630.9312870281678, + "accept_length": 1.9776516235921864 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 698.9315763256182, + "accept_length": 2.2587729126518172 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 563.1619467852893, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 685.8554308455039, + "accept_length": 2.1591340093176212 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 826.5168292170538, + "accept_length": 2.6672259363465063 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 563.1619467852893, + "accept_length": 1.0 + }, + { + "Name": "lmsys/Qwen3-235B-A22B-EAGLE3", + "output_throughput": 636.0480501999019, + "accept_length": 2.001480647431386 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-235B-A22B-Instruct-2507-SpecForge", + "output_throughput": 683.7427107159214, + "accept_length": 2.241436629482574 + } + ] + } + ] + } + }, + "Qwen3-Next-80B-A3B-Instruct-FP8": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 549.6362180919164, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 683.8795985073891, + "accept_length": 3.13391215089175 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 549.6362180919164, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 753.237074543623, + "accept_length": 3.9038018228889597 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 549.6362180919164, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 746.7222279174218, + "accept_length": 4.022678679117706 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 549.6362180919164, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 771.153101164556, + "accept_length": 4.345554699994077 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 549.6362180919164, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 773.4012327870145, + "accept_length": 4.607604467310829 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 863.7773324206034, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1478.3001038430784, + "accept_length": 3.498551418454351 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 863.7773324206034, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1764.2064514729698, + "accept_length": 4.677160426045899 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 863.7773324206034, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1758.0166003158934, + "accept_length": 4.755809947207558 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 863.7773324206034, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1912.6838622508392, + "accept_length": 5.554967332076544 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 863.7773324206034, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1853.434631732593, + "accept_length": 5.756492370623537 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 803.4970369348379, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1095.5102974622082, + "accept_length": 2.581125058112506 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 803.4970369348379, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1157.636689246293, + "accept_length": 2.9156972910237133 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 803.4970369348379, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1197.112468072539, + "accept_length": 3.1331585165547646 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 803.4970369348379, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1127.4364940073876, + "accept_length": 3.0475279197966354 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 803.4970369348379, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1198.9417562126052, + "accept_length": 3.4190589216409535 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 788.4509521573036, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1245.6702060145312, + "accept_length": 3.4647713687985653 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 788.4509521573036, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1527.7120587214345, + "accept_length": 4.612265133111893 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 788.4509521573036, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1536.7723048769212, + "accept_length": 4.676180904522613 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 788.4509521573036, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1628.1293604862747, + "accept_length": 5.4577785667790994 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 788.4509521573036, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1629.7244930267507, + "accept_length": 5.621873496873497 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 916.0337036761792, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1463.1234977160723, + "accept_length": 3.1058026902179443 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 916.0337036761792, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1724.2207417984275, + "accept_length": 3.8462516284893944 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 916.0337036761792, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1734.4894352951553, + "accept_length": 3.9821418050654955 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 916.0337036761792, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1786.8774464735384, + "accept_length": 4.2761952310299485 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 916.0337036761792, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1829.5532782765572, + "accept_length": 4.590307145700787 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 827.3050477430119, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 986.4282909200625, + "accept_length": 2.0752097090844193 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 827.3050477430119, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 981.0983772859984, + "accept_length": 2.1801329261720857 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 827.3050477430119, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1057.6549922432027, + "accept_length": 2.439575219817722 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 827.3050477430119, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 956.6098887389447, + "accept_length": 2.2457481515800852 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 827.3050477430119, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1041.5277102267419, + "accept_length": 2.606484877248997 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 909.8620481543201, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1368.9499756838852, + "accept_length": 2.7362548025140208 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 909.8620481543201, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1457.9918429280988, + "accept_length": 3.1803662497541225 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 909.8620481543201, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1511.274616068283, + "accept_length": 3.3682366894832594 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 909.8620481543201, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1463.9444559000415, + "accept_length": 3.380290412894046 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 909.8620481543201, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Next-80B-A3B-Instruct-FP8-SpecForge", + "output_throughput": 1541.4580844550508, + "accept_length": 3.7385501251645787 + } + ] + } + ] + } + }, + "Qwen3-Coder-30B-A3B-Instruct": { + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1296.1854608851213, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2621.7139434700584, + "accept_length": 3.394971072541166 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1296.1854608851213, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2966.4459091363574, + "accept_length": 4.5011526953450725 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1296.1854608851213, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2236.868611380527, + "accept_length": 3.9489230027326796 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1296.1854608851213, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 3205.2025971977832, + "accept_length": 5.306789266712931 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1296.1854608851213, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2553.012134540716, + "accept_length": 4.221071958746777 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1506.2936922288973, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2992.02067556649, + "accept_length": 3.138553878632709 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1506.2936922288973, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 3328.9058789398114, + "accept_length": 3.9449129401751835 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1506.2936922288973, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2541.3931549111803, + "accept_length": 3.336379596827288 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1506.2936922288973, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 3472.3919294148427, + "accept_length": 4.477776008915068 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 1506.2936922288973, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-30B-A3B-Instruct-SpecForge", + "output_throughput": 2552.5518885328293, + "accept_length": 3.5865930607956185 + } + ] + } + ] + } + }, + "Qwen3-Coder-480B-A35B-Instruct": { + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 470.6571664751315, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 867.5261370310272, + "accept_length": 3.4954686382065345 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 470.6571664751315, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 1044.4475556194586, + "accept_length": 4.68614810868407 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 470.6571664751315, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 945.2207076385645, + "accept_length": 4.2835241878943675 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 470.6571664751315, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 1165.0727231905212, + "accept_length": 5.626203379024545 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 470.6571664751315, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 956.5336674844815, + "accept_length": 4.574128043621322 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.99996954994094, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 846.6405796214389, + "accept_length": 3.0936425388083757 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.99996954994094, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 946.3806786937351, + "accept_length": 3.8547162126548313 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.99996954994094, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 817.5432981932123, + "accept_length": 3.3539182909649066 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.99996954994094, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 983.2554936551461, + "accept_length": 4.260473117512835 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.99996954994094, + "accept_length": 1.0 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Qwen3-Coder-480B-A35B-Instruct-SpecForge-EigenAI", + "output_throughput": 790.2818911646486, + "accept_length": 3.379611891844464 + } + ] + } + ] + } + }, + "Kimi-K2-Instruct": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 337.92445122816076, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 498.355967400969, + "accept_length": 3.271389121751566 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 337.92445122816076, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 538.7660861191819, + "accept_length": 4.120435815920245 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 337.92445122816076, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 476.5166831456105, + "accept_length": 3.5748305647840533 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 337.92445122816076, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 544.16588655688, + "accept_length": 4.655279611582661 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 337.92445122816076, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 459.1757114935756, + "accept_length": 3.4419677544677545 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 492.06079685961566, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 877.2113745892083, + "accept_length": 3.46806357521281 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 492.06079685961566, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 995.8769550545389, + "accept_length": 4.610169876195772 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 492.06079685961566, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 772.6100737625807, + "accept_length": 3.527844083399639 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 492.06079685961566, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 1022.7285831443611, + "accept_length": 5.383128673454291 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 492.06079685961566, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 649.083231514055, + "accept_length": 3.1435862587473253 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 430.9240376244664, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 533.8166177911393, + "accept_length": 2.3897198230461343 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 430.9240376244664, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 526.1187611377575, + "accept_length": 2.738876732312181 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 430.9240376244664, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 473.3129895327435, + "accept_length": 2.394141207153502 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 430.9240376244664, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 488.46384825810924, + "accept_length": 2.7821796546219706 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 430.9240376244664, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 451.126180366313, + "accept_length": 2.536454493323503 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 466.0584238730984, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 779.7838793636296, + "accept_length": 3.364936827816644 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 466.0584238730984, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 868.550857852841, + "accept_length": 4.423030465709301 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 466.0584238730984, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 729.1217213710999, + "accept_length": 3.7321711568938194 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 466.0584238730984, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 897.9039799990946, + "accept_length": 5.162398550153652 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 466.0584238730984, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 669.271164663664, + "accept_length": 3.7044178210408085 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.12137141510016, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 841.5023790421864, + "accept_length": 3.162685632492396 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.12137141510016, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 904.3910288246204, + "accept_length": 3.943605886942718 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.12137141510016, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 716.7319007181034, + "accept_length": 3.1374681580049573 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.12137141510016, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 896.7006322822839, + "accept_length": 4.400262176061309 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 500.12137141510016, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 650.4333056536461, + "accept_length": 3.0780193205478037 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 433.44658979995484, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 647.3644717982133, + "accept_length": 2.9848269628099175 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 433.44658979995484, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 660.0254297132984, + "accept_length": 3.594056395834917 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 433.44658979995484, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 523.0340443308603, + "accept_length": 2.8796471741261027 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 433.44658979995484, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 630.5425124127137, + "accept_length": 3.944647875329984 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 433.44658979995484, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 389.47080223360666, + "accept_length": 2.5096594789735582 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 505.3742994094499, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 783.436424568974, + "accept_length": 2.904452196823693 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 505.3742994094499, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 811.3642458480507, + "accept_length": 3.4622853609057755 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 505.3742994094499, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 699.8111934038128, + "accept_length": 3.0198274205132876 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 505.3742994094499, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 770.4892578818251, + "accept_length": 3.6995331477421103 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 505.3742994094499, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Kimi-K2-Instruct-eagle3", + "output_throughput": 596.3162033813331, + "accept_length": 2.7901899604967983 + } + ] + } + ] + } + }, + "Ling-flash-2.0": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 674.3464018618124, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1144.7606179148752, + "accept_length": 3.4351661916604646 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 674.3464018618124, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1253.4000030615975, + "accept_length": 4.487906489549112 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 674.3464018618124, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1059.7381115819003, + "accept_length": 3.331830155824441 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 674.3464018618124, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1323.0093663978187, + "accept_length": 5.148644964283767 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 674.3464018618124, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1026.8025294413142, + "accept_length": 3.126593214481735 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 762.7113399535667, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1434.6065070935829, + "accept_length": 3.4340471141971713 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 762.7113399535667, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1607.3212268988339, + "accept_length": 4.493397164127635 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 762.7113399535667, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1383.6720582197756, + "accept_length": 3.7931376508179415 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 762.7113399535667, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1685.5692612687462, + "accept_length": 5.218245374511558 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 762.7113399535667, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1330.1086623703009, + "accept_length": 3.793696144088135 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 728.5278345617202, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1022.5890920470158, + "accept_length": 2.392568385378843 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 728.5278345617202, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 990.0430932236113, + "accept_length": 2.648161574313827 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 728.5278345617202, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 914.3899001110539, + "accept_length": 2.5161251562049407 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 728.5278345617202, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 942.3914903299366, + "accept_length": 2.771332137960131 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 728.5278345617202, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 968.0479918450316, + "accept_length": 2.8558805412179527 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 740.2477168580639, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1271.2889448808319, + "accept_length": 3.1471241394625804 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 740.2477168580639, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1353.1437889143726, + "accept_length": 3.9318483282257697 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 740.2477168580639, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1175.4192382338058, + "accept_length": 3.29687986547923 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 740.2477168580639, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1358.9726439538854, + "accept_length": 4.370163501574083 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 740.2477168580639, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1141.7913416362687, + "accept_length": 3.3590013964490297 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 770.3957537752161, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1305.1833791876973, + "accept_length": 2.9790301516097895 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 770.3957537752161, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1366.417326281792, + "accept_length": 3.6103649876590875 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 770.3957537752161, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1130.7868943433502, + "accept_length": 2.8933133857317164 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 770.3957537752161, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1345.6741018953574, + "accept_length": 3.9330923185867093 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 770.3957537752161, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1061.6897228931932, + "accept_length": 2.902182106883942 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 747.7098566179897, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 863.8565336005082, + "accept_length": 1.907102314310342 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 747.7098566179897, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 833.1235940586521, + "accept_length": 2.047546254809973 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 747.7098566179897, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 798.9811798480557, + "accept_length": 1.9372590117256243 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 747.7098566179897, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 763.2761511276084, + "accept_length": 2.0470985454359427 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 747.7098566179897, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 779.3060665006524, + "accept_length": 2.045476819601249 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 794.1289733679167, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1185.7250147683403, + "accept_length": 2.562389392369937 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 794.1289733679167, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1161.8732670284553, + "accept_length": 2.886871902842324 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 794.1289733679167, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1052.640023467198, + "accept_length": 2.6017604302340236 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 794.1289733679167, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1111.996259596397, + "accept_length": 3.0648124985786733 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 794.1289733679167, + "accept_length": 1.0 + }, + { + "Name": "AQ-MedAI/Ling-Flash-2.0-eagle3", + "output_throughput": 1004.4992021266573, + "accept_length": 2.6709053367549105 + } + ] + } + ] + } + }, + "Llama-3.1-8B-Instruct": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 181.81151788749455, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 228.64232714994796, + "accept_length": 1.7165139181419709 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 321.2528041157779, + "accept_length": 2.5481878001819607 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 181.81151788749455, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 213.550264904667, + "accept_length": 1.7634936642258956 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 329.6873220645443, + "accept_length": 2.8537845395516377 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 181.81151788749455, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 195.13619448514442, + "accept_length": 1.7528912619638426 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 251.43922505539766, + "accept_length": 2.2820562939796716 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 181.81151788749455, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 197.901650893672, + "accept_length": 1.7742552127753433 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 317.61058794222197, + "accept_length": 2.9733251079580505 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 181.81151788749455, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 182.0257072155964, + "accept_length": 1.789228234172427 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 240.85801894998306, + "accept_length": 2.367398432594591 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 191.04076784280642, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 399.2995452070592, + "accept_length": 2.7825411590459592 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 492.28246574028134, + "accept_length": 3.4786948176583494 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 191.04076784280642, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 422.40466722576286, + "accept_length": 3.254684892147128 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 594.5033645961273, + "accept_length": 4.624857400180126 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 191.04076784280642, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 387.0489467031037, + "accept_length": 3.3070174292508296 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 480.43534296060534, + "accept_length": 4.116159164796923 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 191.04076784280642, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 413.57783551553456, + "accept_length": 3.489213277012106 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 638.0439777096752, + "accept_length": 5.402844266750837 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 191.04076784280642, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 326.8790406711244, + "accept_length": 3.072066504990206 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 453.306808098541, + "accept_length": 4.25573095185686 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.98120707576373, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 414.90616666264776, + "accept_length": 2.930670028119849 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 404.24667749722187, + "accept_length": 2.8980726819445777 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.98120707576373, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 453.73692243041774, + "accept_length": 3.554148008484563 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 446.6366476858434, + "accept_length": 3.5164393144456105 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.98120707576373, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 338.6308027570883, + "accept_length": 2.9393909722902185 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 346.46724606666106, + "accept_length": 3.0061221366256823 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.98120707576373, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 454.730035166582, + "accept_length": 3.906676145543851 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 450.03198538047087, + "accept_length": 3.855839765261211 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.98120707576373, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 305.1648971387325, + "accept_length": 2.9089536379397125 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 308.00561770283963, + "accept_length": 2.938163437236731 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.91017930680567, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 432.8677712430711, + "accept_length": 3.0469174293472796 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 465.1765542307934, + "accept_length": 3.3398192040568846 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.91017930680567, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 479.1212006261437, + "accept_length": 3.7445769729930163 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 548.9370103875078, + "accept_length": 4.318366474235621 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.91017930680567, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 340.2704451839945, + "accept_length": 2.9425913908717285 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 377.47349118830954, + "accept_length": 3.2519286521546853 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.91017930680567, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 480.3152659024827, + "accept_length": 4.0959237477185155 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 571.4886457684788, + "accept_length": 4.910129659643436 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.91017930680567, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 311.1051926955927, + "accept_length": 2.9338537387017256 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 330.15665770360005, + "accept_length": 3.126203604641593 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.70410640395912, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 380.6915537026263, + "accept_length": 2.6893540748536475 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 439.67672671912396, + "accept_length": 3.16861704188786 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.70410640395912, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 398.3738662742165, + "accept_length": 3.1199565043209523 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 506.22686693578754, + "accept_length": 3.9957244075250427 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.70410640395912, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 322.29847741557273, + "accept_length": 2.771756050751679 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 375.34956052924895, + "accept_length": 3.236171472299629 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.70410640395912, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 391.25705242634194, + "accept_length": 3.334862665932587 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 516.904537338255, + "accept_length": 4.466856034741759 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 189.70410640395912, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 287.68205157705233, + "accept_length": 2.7148899046029547 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 378.8468257829908, + "accept_length": 3.585376494197714 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 185.6534194378935, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 237.18050733350836, + "accept_length": 1.713236561734993 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 258.6437346257605, + "accept_length": 1.9050339301460721 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 185.6534194378935, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 226.67848476067016, + "accept_length": 1.8075300109130592 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 254.48969338840087, + "accept_length": 2.043805528134255 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 185.6534194378935, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 210.94791438286492, + "accept_length": 1.8654798891594593 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 251.07710462288492, + "accept_length": 2.2264818220398923 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 185.6534194378935, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 211.18454065719607, + "accept_length": 1.8434056761268782 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 240.6034453504167, + "accept_length": 2.1029710512950737 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 185.6534194378935, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 183.72672690273865, + "accept_length": 1.7817737292479987 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 229.82170237350869, + "accept_length": 2.250341575212658 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 1, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.4500188461883, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 409.86415544506445, + "accept_length": 2.8552892726009724 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 442.54523731909666, + "accept_length": 3.135712400558006 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.4500188461883, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 438.0519648397228, + "accept_length": 3.3792158666871135 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 507.1290934019136, + "accept_length": 3.936040126357265 + } + ] + }, + { + "batch_size": 1, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.4500188461883, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 352.1689105895484, + "accept_length": 3.026258098612226 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 413.1686528229548, + "accept_length": 3.5475168823860437 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.4500188461883, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 434.1788724748705, + "accept_length": 3.6819800875461333 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 514.2312383540044, + "accept_length": 4.357665531437638 + } + ] + }, + { + "batch_size": 1, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 190.4500188461883, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", + "output_throughput": 311.5910755177637, + "accept_length": 2.9283727399165507 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.1-8B-Instruct-SpecForge", + "output_throughput": 390.64506651929287, + "accept_length": 3.692280754414928 + } + ] + } + ] + } + }, + "Llama-3.3-70B-Instruct": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 453.2156138501392, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 521.4502791575164, + "accept_length": 1.2760798037239203 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Spec for ge", + "output_throughput": 837.9426300003847, + "accept_length": 2.3179247901200304 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 453.2156138501392, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 500.5534332009228, + "accept_length": 1.2836005168205962 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 855.6400225608106, + "accept_length": 2.4851382017038057 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 453.2156138501392, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 500.33326156436937, + "accept_length": 1.3482255389718076 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 758.9001336688345, + "accept_length": 2.12511673151751 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 453.2156138501392, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 483.12653680688, + "accept_length": 1.2856745693167546 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 820.5175400063332, + "accept_length": 2.516910489405022 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 453.2156138501392, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 480.4218686725539, + "accept_length": 1.3936331604189096 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 739.405741336959, + "accept_length": 2.222061210294459 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 567.3739460148672, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1088.844896763402, + "accept_length": 2.3720131878590123 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1273.7733416283656, + "accept_length": 2.841736535013628 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 567.3739460148672, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1122.2476729474943, + "accept_length": 2.5920045204124875 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1382.9357431087456, + "accept_length": 3.243898689873717 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 567.3739460148672, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1112.8479569335152, + "accept_length": 2.792588962605549 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1274.2110431983278, + "accept_length": 3.2416170775479363 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 567.3739460148672, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1079.9951811356827, + "accept_length": 2.6718376973892366 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1327.6044700788502, + "accept_length": 3.3766338373668217 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 567.3739460148672, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1090.3170854344964, + "accept_length": 2.966812280063099 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1215.8347875575441, + "accept_length": 3.3641021480547684 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 540.4640557255416, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1234.647877556777, + "accept_length": 2.9232673267326734 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1238.4736758319698, + "accept_length": 2.9606951984177083 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 540.4640557255416, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1377.8052334866013, + "accept_length": 3.5324281309061973 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1409.5100765643524, + "accept_length": 3.6175162329362442 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 540.4640557255416, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1129.6661036217977, + "accept_length": 3.143848893296669 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1108.3072501756835, + "accept_length": 3.2248797608215263 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 540.4640557255416, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1425.2993761886291, + "accept_length": 3.8789368991048736 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1440.3671955624673, + "accept_length": 3.97791186891054 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 540.4640557255416, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1069.4986663607351, + "accept_length": 3.1943331425300516 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1033.773238205561, + "accept_length": 3.2422141262192974 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.9500728009846, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1194.0875984832494, + "accept_length": 2.6663626344392504 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1290.1122375104421, + "accept_length": 2.925804965875309 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.9500728009846, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1282.7936401185236, + "accept_length": 3.0671719811813904 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1426.372333907719, + "accept_length": 3.436568804650481 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.9500728009846, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1090.1088508973057, + "accept_length": 2.8127895941495002 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1174.0867819009864, + "accept_length": 3.0611013660766493 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.9500728009846, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1267.8737053510965, + "accept_length": 3.1906793120660706 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1407.8140138598972, + "accept_length": 3.6735002608242047 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.9500728009846, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1013.2705272855593, + "accept_length": 2.7776112847805305 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 968.2027451202639, + "accept_length": 2.742653690956563 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.8834615148919, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1210.6010917932015, + "accept_length": 2.723797958423008 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1295.014267720614, + "accept_length": 2.952023346303502 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.8834615148919, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1303.4195570335166, + "accept_length": 3.133414966360772 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1423.2736941362525, + "accept_length": 3.4980468448438247 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.8834615148919, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1070.711661408102, + "accept_length": 2.735034762087001 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1154.785652335772, + "accept_length": 2.9811645516106386 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.8834615148919, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1279.5345355421975, + "accept_length": 3.284394784770605 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1399.3991191944933, + "accept_length": 3.716324359708698 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 560.8834615148919, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 1013.3765756840332, + "accept_length": 2.773990564681233 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1035.4140338795994, + "accept_length": 2.933293078243183 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 512.5751663875466, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 704.0737829344649, + "accept_length": 1.645732050137249 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 936.4940018423655, + "accept_length": 2.2541347317466722 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 512.5751663875466, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 684.0195321200449, + "accept_length": 1.702027072988232 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 933.0572305312112, + "accept_length": 2.39442380929992 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 512.5751663875466, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 618.4946534541955, + "accept_length": 1.7860533893688224 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 700.886442439991, + "accept_length": 2.281622206910129 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 512.5751663875466, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 652.1412786559076, + "accept_length": 1.7116903633491312 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 887.7001871678323, + "accept_length": 2.452738257649581 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 512.5751663875466, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 635.2599880909434, + "accept_length": 1.9610333607746286 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 854.0347909075315, + "accept_length": 2.589833798374378 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 575.6879373469175, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 962.5545831639148, + "accept_length": 2.0451300999292217 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1020.0538308626681, + "accept_length": 2.1911976817371235 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 575.6879373469175, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 963.8356757692138, + "accept_length": 2.1687507495755036 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 1039.643962895085, + "accept_length": 2.3552079123829617 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 575.6879373469175, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 890.1003387342033, + "accept_length": 2.226321240698847 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 960.5616523564485, + "accept_length": 2.4811411267352264 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 575.6879373469175, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 916.6826693888017, + "accept_length": 2.1849745643049188 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 984.4877550429275, + "accept_length": 2.4152394292465176 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 575.6879373469175, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-LLaMA3.3-Instruct-70B", + "output_throughput": 838.0962787179271, + "accept_length": 2.3145643059121785 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-Specforge", + "output_throughput": 924.0808096194634, + "accept_length": 2.573260793115575 + } + ] + } + ] + } + }, + "Llama-4-Scout-17B-16E-Instruct": { + "gsm8k": { + "benchmark_name": "gsm8k", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 455.9311905316165, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 816.6176343207234, + "accept_length": 2.435108707729916 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 908.8655650704263, + "accept_length": 3.1118742007294085 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 455.9311905316165, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 806.5328373116205, + "accept_length": 2.6234459324405357 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 971.8534490877095, + "accept_length": 3.8715801886792454 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 455.9311905316165, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 708.8133468064259, + "accept_length": 2.146746247607535 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 818.3072714693558, + "accept_length": 2.918526679710503 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 455.9311905316165, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 765.9810114809961, + "accept_length": 2.675257522087863 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 957.227019602509, + "accept_length": 4.307217442700466 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 455.9311905316165, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 675.0775309782273, + "accept_length": 2.144316290813106 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 814.5839518607636, + "accept_length": 2.627502101582583 + } + ] + } + ] + }, + "math500": { + "benchmark_name": "math500", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 561.835811548351, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1478.9989946720648, + "accept_length": 2.366719134681358 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1884.3462895109676, + "accept_length": 3.238557789111507 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 561.835811548351, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1447.5513200323323, + "accept_length": 2.5898901840327406 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 2100.7682204066577, + "accept_length": 4.153214423200308 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 561.835811548351, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1199.1485073659853, + "accept_length": 2.489558557182447 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1457.2169829849418, + "accept_length": 3.2046972238757507 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 561.835811548351, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1330.0337890073868, + "accept_length": 2.648556845221877 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 2110.3314050998847, + "accept_length": 4.7805795395081105 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 561.835811548351, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1153.7706965189202, + "accept_length": 2.6314392278632304 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1369.6607164745208, + "accept_length": 3.2076523352436657 + } + ] + } + ] + }, + "mtbench": { + "benchmark_name": "mtbench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 502.10114738381606, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1252.9681990096112, + "accept_length": 2.3541095408844828 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1302.3829223511154, + "accept_length": 2.4913843888070693 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 502.10114738381606, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1225.4607594389363, + "accept_length": 2.5648559607722956 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1312.399917450856, + "accept_length": 2.836414637256152 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 502.10114738381606, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 953.148992300308, + "accept_length": 2.222710749523974 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 967.1281111811169, + "accept_length": 2.3256101583113455 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 502.10114738381606, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1157.0433602013916, + "accept_length": 2.649528603387664 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1276.9552963643773, + "accept_length": 3.0189181867437243 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 502.10114738381606, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 940.9893388280037, + "accept_length": 2.3959043407227965 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1010.4098410869198, + "accept_length": 2.7008052625609618 + } + ] + } + ] + }, + "humaneval": { + "benchmark_name": "humaneval", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 631.8746804703884, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1515.800628974162, + "accept_length": 2.664927494512612 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1749.0012751674196, + "accept_length": 3.224152798137449 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 631.8746804703884, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1556.515161340629, + "accept_length": 3.085438335809807 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1921.2922045342316, + "accept_length": 4.140846637369973 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 631.8746804703884, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1201.849883743592, + "accept_length": 2.6006220481511346 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1393.1592557980014, + "accept_length": 3.1744799971652315 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 631.8746804703884, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1456.346786965349, + "accept_length": 3.2582381225462083 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1944.8214954525663, + "accept_length": 4.7947306331104995 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 631.8746804703884, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1109.058302621911, + "accept_length": 2.6508010386556267 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1234.7042057027743, + "accept_length": 3.0442784990549376 + } + ] + } + ] + }, + "livecodebench": { + "benchmark_name": "livecodebench", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 484.2501137181978, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1598.2921930690502, + "accept_length": 2.487202280374381 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1933.9962764283844, + "accept_length": 3.14740116583215 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 484.2501137181978, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1601.2688464385185, + "accept_length": 2.8043640587405627 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 2144.3319751584095, + "accept_length": 3.983057732747085 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 484.2501137181978, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1051.7266219288254, + "accept_length": 2.1138485934104656 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1320.656674087923, + "accept_length": 2.7145795398417976 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 484.2501137181978, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1501.558947290443, + "accept_length": 2.929916684169992 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 2170.188140733029, + "accept_length": 4.55060712303548 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 484.2501137181978, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1009.5574686537159, + "accept_length": 2.2590065740745002 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1249.8114756626915, + "accept_length": 2.8130523194007555 + } + ] + } + ] + }, + "financeqa": { + "benchmark_name": "financeqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 288.9007335547823, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1022.713052476267, + "accept_length": 1.7952034022379475 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1189.61672405822, + "accept_length": 2.2164571332464367 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 288.9007335547823, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 963.8209003406079, + "accept_length": 1.8240590609583607 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1171.8275957081507, + "accept_length": 2.408275220827522 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 288.9007335547823, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 755.8055387643059, + "accept_length": 1.780077619663648 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 887.65933899505, + "accept_length": 2.1907344347752975 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 288.9007335547823, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 885.0003924094965, + "accept_length": 1.864155494076754 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1084.5573704005851, + "accept_length": 2.459442783236034 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 288.9007335547823, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 773.7660016870891, + "accept_length": 2.05643096671835 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 838.3207906571789, + "accept_length": 2.1910908349096845 + } + ] + } + ] + }, + "gpqa": { + "benchmark_name": "gpqa", + "results": [ + { + "batch_size": 8, + "steps": 3, + "topk": 1, + "num_draft_tokens": 4, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 541.0010469896803, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1320.0198779778916, + "accept_length": 2.0166714112874526 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1482.2781495871964, + "accept_length": 2.3200242800296755 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 1, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 541.0010469896803, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1258.0775283103167, + "accept_length": 2.135039169677331 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1468.3432054658438, + "accept_length": 2.5528455284552845 + } + ] + }, + { + "batch_size": 8, + "steps": 5, + "topk": 3, + "num_draft_tokens": 6, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 541.0010469896803, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1405.110892125768, + "accept_length": 2.8834021014937705 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1502.213627081269, + "accept_length": 3.0623772161357583 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 1, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 541.0010469896803, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1148.5409144989237, + "accept_length": 2.1684843736177633 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1379.1223204247422, + "accept_length": 2.672381928590287 + } + ] + }, + { + "batch_size": 8, + "steps": 7, + "topk": 4, + "num_draft_tokens": 8, + "metrics": [ + { + "Name": "Wihtout EAGLE3", + "output_throughput": 541.0010469896803, + "accept_length": 1.0 + }, + { + "Name": "lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1", + "output_throughput": 1345.7377508882935, + "accept_length": 3.044341630328194 + }, + { + "Name": "lmsys/SGLang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-SpecForge", + "output_throughput": 1474.1967930541948, + "accept_length": 3.315005686664771 + } + ] + } + ] + } + } +} diff --git a/progress/github/SpecForge/docs/spec_bundle/src/components/BenchmarkDashboard.vue b/progress/github/SpecForge/docs/spec_bundle/src/components/BenchmarkDashboard.vue new file mode 100644 index 0000000000000000000000000000000000000000..a5d33cc912211e9311f78422b0436a3e605a0dbd --- /dev/null +++ b/progress/github/SpecForge/docs/spec_bundle/src/components/BenchmarkDashboard.vue @@ -0,0 +1,601 @@ + + + + + diff --git a/progress/github/SpecForge/docs/spec_bundle/vite.config.js b/progress/github/SpecForge/docs/spec_bundle/vite.config.js new file mode 100644 index 0000000000000000000000000000000000000000..d747468c3295796728aabd7aae67de54928095c6 --- /dev/null +++ b/progress/github/SpecForge/docs/spec_bundle/vite.config.js @@ -0,0 +1,23 @@ +import { defineConfig } from 'vite' +import vue from '@vitejs/plugin-vue' + +// https://vite.dev/config/ +export default defineConfig({ + plugins: [vue()], + base: './', // Use relative paths for deployment + build: { + outDir: 'dist', + assetsDir: 'assets', + sourcemap: false, + minify: 'esbuild', // Use esbuild for faster minification (Vite built-in) + rollupOptions: { + output: { + manualChunks: { + 'vue-vendor': ['vue'], + 'echarts-vendor': ['echarts', 'vue-echarts'], + 'csv-vendor': ['papaparse'] + } + } + } + } +}) diff --git a/progress/github/SpecForge/tests/ci/gpu_lock_exec.py b/progress/github/SpecForge/tests/ci/gpu_lock_exec.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca44c6b66c73ad26be2eac0505626857aec64a4 --- /dev/null +++ b/progress/github/SpecForge/tests/ci/gpu_lock_exec.py @@ -0,0 +1,249 @@ +import argparse +import fcntl +import os +import random +import sys +import time +from typing import List + +SLEEP_BACKOFF = 5.0 + + +def main(): + """ + Remark: Can use `lslocks` to debug + """ + args = _parse_args() + + if args.print_only: + _execute_print_only(args) + return + + fd_locks = _try_acquire(args) + + dev_list = ",".join(str(x.gpu_id) for x in fd_locks) + os.environ["CUDA_VISIBLE_DEVICES"] = dev_list + + if args.env: + for env_var in args.env: + name, value = env_var.split("=") + os.environ[name] = value + print( + f"[gpu_lock_exec] Setting environment variable: {name}={value}", + flush=True, + ) + print(f"[gpu_lock_exec] Acquired GPUs: {dev_list}", flush=True) + + _os_execvp(args) + + +def _os_execvp(args): + cmd = args.cmd + if cmd[0] == "--": + cmd = cmd[1:] + + # propagate the environment variables + os.execvp(cmd[0], cmd) + + +def _parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--count", type=int, default=None, help="Acquire this many GPUs (any free ones)" + ) + p.add_argument( + "--devices", + type=str, + default=None, + help="Comma separated explicit devices to acquire (e.g. 0,1)", + ) + p.add_argument( + "--total-gpus", type=int, default=8, help="Total GPUs on the machine" + ) + p.add_argument( + "--timeout", + type=int, + default=3600, + help="Seconds to wait for locks before failing", + ) + p.add_argument( + "--env", + type=str, + default=None, + nargs="*", + help="Environment variables to set (e.g. HF_TOKEN=1234567890)", + ) + p.add_argument( + "--lock-path-pattern", + type=str, + default="/dev/shm/custom_gpu_lock_{gpu_id}.lock", + help='Filename pattern with "{gpu_id}" placeholder', + ) + p.add_argument( + "--print-only", + action="store_true", + help="Probe free devices and print them (does NOT hold locks)", + ) + p.add_argument( + "cmd", + nargs=argparse.REMAINDER, + help="Command to exec after '--' (required unless --print-only)", + ) + args = p.parse_args() + + if "{gpu_id}" not in args.lock_path_pattern: + raise Exception("ERROR: --lock-path-pattern must contain '{i}' placeholder.") + + if not args.cmd and not args.print_only: + raise Exception("ERROR: missing command to run. Use -- before command.") + + return args + + +def _execute_print_only(args): + free = [] + _ensure_lock_files(path_pattern=args.lock_path_pattern, total_gpus=args.total_gpus) + for i in range(args.total_gpus): + try: + fd_lock = FdLock(args.lock_path_pattern, i) + fd_lock.open() + try: + fd_lock.lock() + fcntl.flock(fd_lock.fd, fcntl.LOCK_UN) + free.append(i) + except BlockingIOError: + pass + fd_lock.close() + except Exception as e: + print( + f"Warning: Error while probing lock: {e}", file=sys.stderr, flush=True + ) + + print("Free GPUs:", ",".join(str(x) for x in free), flush=True) + + +def _try_acquire(args): + if args.devices: + devs = _parse_devices(args.devices) + return _try_acquire_specific(devs, args.lock_path_pattern, args.timeout) + else: + return _try_acquire_count( + args.count, args.total_gpus, args.lock_path_pattern, args.timeout + ) + + +def _try_acquire_specific(devs: List[int], path_pattern: str, timeout: int): + fd_locks = [] + start = time.time() + try: + _ensure_lock_files(path_pattern, max(devs) + 1) + for gpu_id in devs: + fd_lock = FdLock(path_pattern, gpu_id=gpu_id) + fd_lock.open() + while True: + try: + fd_lock.lock() + break + except BlockingIOError: + if time.time() - start > timeout: + raise TimeoutError(f"Timeout while waiting for GPU {gpu_id}") + time.sleep(SLEEP_BACKOFF * random.random()) + fd_locks.append(fd_lock) + return fd_locks + except Exception as e: + print( + f"Error during specific GPU acquisition: {e}", file=sys.stderr, flush=True + ) + for fd_lock in fd_locks: + fd_lock.close() + raise + + +def _try_acquire_count(count: int, total_gpus: int, path_pattern: str, timeout: int): + start = time.time() + _ensure_lock_files(path_pattern, total_gpus) + while True: + fd_locks: List = [] + for gpu_id in range(total_gpus): + fd_lock = FdLock(path_pattern, gpu_id=gpu_id) + fd_lock.open() + try: + fd_lock.lock() + except BlockingIOError: + fd_lock.close() + continue + + fd_locks.append(fd_lock) + if len(fd_locks) == count: + return fd_locks + + gotten_gpu_ids = [x.gpu_id for x in fd_locks] + for fd_lock in fd_locks: + fd_lock.close() + del fd_lock + + if time.time() - start > timeout: + raise TimeoutError(f"Timeout acquiring {count} GPUs (out of {total_gpus})") + + print( + f"[gpu_lock_exec] try_acquire_count failed, sleep and retry (only got: {gotten_gpu_ids})", + flush=True, + ) + time.sleep(SLEEP_BACKOFF * random.random()) + + +class FdLock: + def __init__(self, path_pattern, gpu_id: int): + self.gpu_id = gpu_id + self.path = _get_lock_path(path_pattern, self.gpu_id) + self.fd = None + + def open(self): + assert self.fd is None + self.fd = open(self.path, "a+") + # try to avoid lock disappear when execvp + os.set_inheritable(self.fd.fileno(), True) + + def lock(self): + assert self.fd is not None + fcntl.flock(self.fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + + def close(self): + assert self.fd is not None + try: + self.fd.close() + except Exception as e: + print( + f"Warning: Failed to close file descriptor: {e}", + file=sys.stderr, + flush=True, + ) + self.fd = None + + +def _ensure_lock_files(path_pattern: str, total_gpus: int): + lock_dir = os.path.dirname(path_pattern) + if lock_dir: + os.makedirs(lock_dir, exist_ok=True) + for gpu_id in range(total_gpus): + p = _get_lock_path(path_pattern, gpu_id) + try: + open(p, "a").close() + except Exception as e: + print( + f"Warning: Could not create lock file {p}: {e}", + file=sys.stderr, + flush=True, + ) + + +def _get_lock_path(path_pattern: str, gpu_id: int) -> str: + return path_pattern.format(gpu_id=gpu_id) + + +def _parse_devices(s: str) -> List[int]: + return [int(x) for x in s.split(",") if x.strip() != ""] + + +if __name__ == "__main__": + main() diff --git a/progress/github/SpecForge/tests/test_data/__init__.py b/progress/github/SpecForge/tests/test_data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_data/test_parsers.py b/progress/github/SpecForge/tests/test_data/test_parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..064e1587bb9f944649c91fac55b6268277cbb90d --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_parsers.py @@ -0,0 +1,204 @@ +import json +import os +import unittest +from typing import Any, Dict, List, Optional + +from transformers import AutoTokenizer + +from specforge.data.preprocessing import preprocess_conversations +from specforge.data.template import TEMPLATE_REGISTRY + + +class TestTemplatePreprocessing(unittest.TestCase): + # Configuration section + SAVE_REFERENCE = False + REF_DIR = os.path.join(os.path.dirname(__file__), "test_references") + + @classmethod + def setUpClass(cls): + """Initialize standard test data""" + cls.max_length = 65535 + if not os.path.exists(cls.REF_DIR): + os.makedirs(cls.REF_DIR) + + # 1. General model test data (Qwen, DeepSeek, etc.) + cls.standard_messages = [ + [ + {"role": "user", "content": "Who are you?"}, + {"role": "assistant", "content": "My name is Qwen2."}, + {"role": "user", "content": "How old are you?"}, + {"role": "assistant", "content": "11 years old."}, + ] + ] + + # 2. GPT-OSS Dedicated Test Data (Including Analysis and Final Channel) + cls.gpt_oss_messages = [ + [ + {"role": "user", "content": "Explain Quantum Physics."}, + { + "role": "assistant_analysis", + "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.", + }, + { + "role": "assistant_final", + "content": "Quantum physics is the study of matter and energy at the most fundamental level...", + }, + {"role": "user", "content": "Explain Quantum Physics."}, + {"role": "assistant_final", "content": "I'm Qwen"}, + ] + ] + + # 3. Tool-Use Test Data + cls.tool_use_messages = [ + [ + { + "role": "user", + "content": "What's the weather like in Beijing today?", + }, + { + "role": "assistant", + "content": "I'll check the current weather in Beijing for you.", + }, + { + "role": "tool", + "content": '{"location": "Beijing", "temperature": 22, "condition": "Sunny"}', + }, + { + "role": "assistant", + "content": "The current weather in Beijing is sunny with a temperature of 22°C.", + }, + { + "role": "tool", + "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}', + }, + { + "role": "tool", + "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}', + }, + { + "role": "user", + "content": "Great! Can you also tell me if it will rain tomorrow?", + }, + { + "role": "assistant", + "content": "Based on the forecast, there will be no rain tomorrow—expect clear skies all day.", + }, + ] + ] + + def _get_ref_path(self, template_key: str, message_label: str = "standard"): + return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json") + + def _run_template_test( + self, + model_name: str, + template_key: str, + messages: Optional[List[List[Dict[str, str]]]] = None, + ): + """Encapsulate common test and regression validation logic""" + + # Use the input message or the default standard message. + target_messages = messages if messages is not None else self.standard_messages + message_label = None + if target_messages == self.standard_messages: + message_label = "standard" + elif target_messages == self.gpt_oss_messages: + message_label = "gpt-oss" + elif target_messages == self.tool_use_messages: + message_label = "tool-use" + else: + raise ValueError("Invalid message set") + print(f"\n>>> Running: {template_key} ({model_name}) {message_label}") + + # 1. Initialize tokenizer and template + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + chat_template = TEMPLATE_REGISTRY.get(template_key) + + # 2. Preprocess conversations + res = preprocess_conversations( + tokenizer, target_messages, chat_template, self.max_length + ) + # Extract current result + current_data = { + "input_ids": res["input_ids"][0][0].tolist(), + "loss_mask": res["loss_mask"][0][0].tolist(), + } + + ref_path = self._get_ref_path(template_key, message_label) + # 3. Branch logic: update reference or perform comparison + if self.SAVE_REFERENCE: + with open(ref_path, "w", encoding="utf-8") as f: + json.dump(current_data, f) + print(f" [INFO] Reference saved to {ref_path}") + else: + if not os.path.exists(ref_path): + self.fail( + f"Reference file not found for {template_key}. Set SAVE_REFERENCE=True." + ) + + with open(ref_path, "r", encoding="utf-8") as f: + ref_data = json.load(f) + + self.assertListEqual(current_data["input_ids"], ref_data["input_ids"]) + self.assertListEqual(current_data["loss_mask"], ref_data["loss_mask"]) + print(f" [PASS] Regression test passed for {template_key}") + + # 4. Debug output + self.debug_show_loss_mask(res, tokenizer) + + @staticmethod + def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer): + input_ids = res["input_ids"][0][0].tolist() + loss_mask = res["loss_mask"][0][0].tolist() + RED, RESET = "\033[91m", "\033[0m" + print("-" * 30) + for tid, m in zip(input_ids, loss_mask): + txt = tokenizer.decode([tid]) + txt = txt.replace("\n", "\\n") + print(f"{RED if m == 1 else ''}{txt}{RESET}", end="") + print("\n" + "-" * 30) + + ## The Following are tests. Each test corresponds to a specific model and template. + + def test_deepseek(self): + self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3") + + def test_deepseek_v32(self): + self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32") + + def test_qwen3_thinking(self): + self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-thinking") + + def test_qwen3_instruct(self): + self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct") + + def test_qwen3_next_instruct(self): + self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen") + + def test_kimi_k2_thinking(self): + self._run_template_test("moonshotai/Kimi-K2-Thinking", "kimi-k2-thinking") + + def test_kimi_k2_instruct(self): + self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct") + + def test_qwen3_next_thinking(self): + self._run_template_test( + "Qwen/Qwen3-Next-80B-A3B-Thinking", "qwen3-next-thinking" + ) + + def test_gpt_oss(self): + self._run_template_test( + "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages + ) + + def test_ling_flash_2_0(self): + self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0") + + def test_qwen3_instruct_with_tools(self): + self._run_template_test( + "Qwen/Qwen3-0.6B", "qwen3-instruct", messages=self.tool_use_messages + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/progress/github/SpecForge/tests/test_data/test_preprocessing.py b/progress/github/SpecForge/tests/test_data/test_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..5301aaf9bd147b63f00b76148ac0009865f7f916 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_preprocessing.py @@ -0,0 +1,354 @@ +import unittest + +import torch +from transformers import AutoTokenizer + +from specforge.data.preprocessing import preprocess_conversations +from specforge.data.template import TEMPLATE_REGISTRY + + +# Utility function for visual debugging +def visualize_loss_mask(tokenizer, input_ids, loss_mask): + """Utility function to visualize which tokens contribute to loss.""" + RED = "\033[91m" # Non-assistant tokens (loss_mask = 0) + GREEN = "\033[92m" # Assistant tokens (loss_mask = 1) + RESET = "\033[0m" + + print("\nLoss Mask Visualization:") + print("RED = Non-assistant tokens (loss_mask = 0)") + print("GREEN = Assistant tokens (loss_mask = 1)") + print("-" * 50) + + # Handle both 1D and 2D tensors - flatten if needed + if input_ids.dim() > 1: + input_ids = input_ids.flatten() + if loss_mask.dim() > 1: + loss_mask = loss_mask.flatten() + + if len(input_ids) == 0 or len(loss_mask) == 0: + print("Empty input") + return + + current_mask = loss_mask[0].item() + current_ids = [] + + for i in range(len(input_ids)): + if current_mask == loss_mask[i].item(): + current_ids.append(input_ids[i].item()) + else: + if hasattr(tokenizer, "decode"): + decoded_text = tokenizer.decode(current_ids, skip_special_tokens=False) + else: + decoded_text = " ".join([f"token_{id}" for id in current_ids]) + if current_mask == 0: + print(f"{RED}{decoded_text}{RESET}", end="") + else: + print(f"{GREEN}{decoded_text}{RESET}", end="") + current_ids = [input_ids[i].item()] + current_mask = loss_mask[i].item() + + # Print remaining tokens + if current_ids: + if hasattr(tokenizer, "decode"): + decoded_text = tokenizer.decode(current_ids, skip_special_tokens=False) + else: + decoded_text = " ".join([f"token_{id}" for id in current_ids]) + if current_mask == 0: + print(f"{RED}{decoded_text}{RESET}") + else: + print(f"{GREEN}{decoded_text}{RESET}") + print("\n" + "-" * 50) + + +class TestPreprocessing(unittest.TestCase): + """Test suite for conversation preprocessing and loss mask generation.""" + + def setUp(self): + """Set up test fixtures with Qwen3-8B tokenizer and template.""" + self.model_path = "Qwen/Qwen3-8B" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + self.chat_template = TEMPLATE_REGISTRY.get("qwen") + self.max_length = 512 + + def test_conversation_preprocessing_basic(self): + """Test basic conversation preprocessing with assistant response identification.""" + conversations = [ + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "The answer is 4."}, + ] + ] + + results = preprocess_conversations( + tokenizer=self.tokenizer, + conversations=conversations, + chat_template=self.chat_template, + max_length=self.max_length, + is_preformatted=False, + ) + + # Check structure + self.assertIn("input_ids", results) + self.assertIn("loss_mask", results) + self.assertIn("attention_mask", results) + self.assertEqual(len(results["input_ids"]), 1) + self.assertEqual(len(results["loss_mask"]), 1) + self.assertEqual(len(results["attention_mask"]), 1) + + # Verify tensor shapes match + input_ids = results["input_ids"][0].squeeze() + loss_mask = results["loss_mask"][0].squeeze() + attention_mask = results["attention_mask"][0].squeeze() + + self.assertEqual(input_ids.shape, loss_mask.shape) + self.assertEqual(input_ids.shape, attention_mask.shape) + + # Check that some tokens are marked for loss (assistant response) + self.assertTrue( + torch.any(loss_mask == 1), "No tokens marked for loss computation" + ) + + # Check that some tokens are not marked for loss (system/user parts) + self.assertTrue( + torch.any(loss_mask == 0), "All tokens marked for loss computation" + ) + + # Verify the complete assistant response is captured in the loss mask + assistant_token_indices = torch.where(loss_mask == 1)[0] + if len(assistant_token_indices) > 0: + assistant_tokens = input_ids[assistant_token_indices] + assistant_text = self.tokenizer.decode( + assistant_tokens, skip_special_tokens=False + ) + expected_assistant_text = ( + "\n\n\n\nThe answer is 4.<|im_end|>\n" + ) + self.assertEqual( + assistant_text, + expected_assistant_text, + f"Assistant text does not match exactly. Expected: {repr(expected_assistant_text)}, Got: {repr(assistant_text)}", + ) + + def test_multiple_turns_conversation(self): + """Test conversation with multiple user-assistant turns.""" + conversations = [ + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "The answer is 4."}, + {"role": "user", "content": "Are you sure?"}, + {"role": "assistant", "content": "Yes, I'm certain."}, + ] + ] + + results = preprocess_conversations( + tokenizer=self.tokenizer, + conversations=conversations, + chat_template=self.chat_template, + max_length=self.max_length, + is_preformatted=False, + ) + + input_ids = results["input_ids"][0].squeeze() + loss_mask = results["loss_mask"][0].squeeze() + + # Get all assistant response tokens + assistant_token_indices = torch.where(loss_mask == 1)[0] + self.assertTrue( + len(assistant_token_indices) > 0, "No assistant tokens identified" + ) + + # Decode assistant tokens to verify both responses are captured + assistant_tokens = input_ids[assistant_token_indices] + assistant_text = self.tokenizer.decode( + assistant_tokens, skip_special_tokens=False + ) + + # Exact match for the complete assistant text from both turns + expected_assistant_text = "The answer is 4.<|im_end|>\n\n\n\n\nYes, I'm certain.<|im_end|>\n" + self.assertEqual( + assistant_text, + expected_assistant_text, + f"Assistant text does not match exactly. Expected: {repr(expected_assistant_text)}, Got: {repr(assistant_text)}", + ) + + def test_preformatted_conversation(self): + """Test preprocessing of pre-formatted conversation strings.""" + preformatted_conversations = [ + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n" + ] + + results = preprocess_conversations( + tokenizer=self.tokenizer, + conversations=preformatted_conversations, + chat_template=self.chat_template, + max_length=self.max_length, + is_preformatted=True, + ) + + # Check basic structure + self.assertEqual(len(results["input_ids"]), 1) + + input_ids = results["input_ids"][0].squeeze() + loss_mask = results["loss_mask"][0].squeeze() + + # Verify assistant response is identified + self.assertTrue( + torch.any(loss_mask == 1), + "No assistant tokens marked in preformatted input", + ) + + # Extract and verify assistant content + assistant_token_indices = torch.where(loss_mask == 1)[0] + assistant_tokens = input_ids[assistant_token_indices] + assistant_text = self.tokenizer.decode( + assistant_tokens, skip_special_tokens=False + ) + + # Check for exact match of the expected assistant response + expected_assistant_text = "Python is a programming language.<|im_end|>\n" + self.assertEqual( + assistant_text, + expected_assistant_text, + f"Assistant text does not match exactly. Expected: {repr(expected_assistant_text)}, Got: {repr(assistant_text)}", + ) + + def test_assistant_span_boundaries(self): + """Test that assistant span boundaries are correctly identified without truncation.""" + test_cases = [ + { + "name": "Short response", + "conversation": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ], + "expected_assistant_text": "\n\n\n\nHello!<|im_end|>\n", + }, + { + "name": "Response with punctuation", + "conversation": [ + {"role": "user", "content": "What's your name?"}, + {"role": "assistant", "content": "I'm Claude, an AI assistant."}, + ], + "expected_assistant_text": "\n\n\n\nI'm Claude, an AI assistant.<|im_end|>\n", + }, + { + "name": "Multi-sentence response", + "conversation": [ + {"role": "user", "content": "Tell me about Python."}, + { + "role": "assistant", + "content": "Python is a programming language. It's very popular for AI.", + }, + ], + "expected_assistant_text": "\n\n\n\nPython is a programming language. It's very popular for AI.<|im_end|>\n", + }, + { + "name": "Response with special characters", + "conversation": [ + {"role": "user", "content": "Show me math."}, + { + "role": "assistant", + "content": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.", + }, + ], + "expected_assistant_text": "\n\n\n\nSure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.<|im_end|>\n", + }, + ] + + for test_case in test_cases: + with self.subTest(test_case["name"]): + conversations = [test_case["conversation"]] + + results = preprocess_conversations( + tokenizer=self.tokenizer, + conversations=conversations, + chat_template=self.chat_template, + max_length=self.max_length, + is_preformatted=False, + ) + + input_ids = results["input_ids"][0].squeeze() + loss_mask = results["loss_mask"][0].squeeze() + + # Extract assistant tokens + assistant_token_indices = torch.where(loss_mask == 1)[0] + self.assertTrue( + len(assistant_token_indices) > 0, + f"No assistant tokens found for test case: {test_case['name']}", + ) + + assistant_tokens = input_ids[assistant_token_indices] + assistant_text = self.tokenizer.decode( + assistant_tokens, skip_special_tokens=False + ) + + # Verify exact match of the expected assistant text + expected_assistant_text = test_case["expected_assistant_text"] + self.assertEqual( + assistant_text, + expected_assistant_text, + f"Assistant text does not match exactly for test case '{test_case['name']}'. Expected: {repr(expected_assistant_text)}, Got: {repr(assistant_text)}", + ) + + # Additional check: ensure no user content leaked into assistant spans + user_content = test_case["conversation"][0]["content"] + # Check if user content appears in assistant text (should not happen with exact matching) + self.assertNotIn( + user_content, + assistant_text, + f"User content '{user_content}' found in assistant spans for test case '{test_case['name']}': '{assistant_text}'", + ) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + + suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPreprocessing)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) + + # Commented-out example for using visualize_loss_mask function directly + """ + # Example usage of visualize_loss_mask for debugging/visualization + model_path = "Qwen/Qwen3-8B" + tokenizer = AutoTokenizer.from_pretrained(model_path) + chat_template = TEMPLATE_REGISTRY.get("qwen") + + # Using conversations list + conversations = [ + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "The answer is 4."}, + {"role": "user", "content": "I don't think that's right"}, + {"role": "assistant", "content": "I'm pretty sure it's 4."}, + ], + [ + {"role": "user", "content": "How do you boil water?"}, + {"role": "assistant", "content": "Use a stove."}, + ], + ] + results = preprocess_conversations( + tokenizer=tokenizer, + conversations=conversations, + chat_template=chat_template, + max_length=512, + is_preformatted=False, + ) + for i in range(len(results["input_ids"])): + visualize_loss_mask(tokenizer, results["input_ids"][i], results["loss_mask"][i]) + + # Using preformatted conversation + preformatted_conversations = [ + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nI don't think that's right<|im_end|>\n<|im_start|>assistant\n\nI know 2+2 is 4\n\nI'm pretty sure it's 4.<|im_end|>\n", + ] + results = preprocess_conversations( + tokenizer=tokenizer, + conversations=preformatted_conversations, + chat_template=chat_template, + max_length=512, + is_preformatted=True, + ) + for i in range(len(results["input_ids"])): + visualize_loss_mask(tokenizer, results["input_ids"][i], results["loss_mask"][i]) + """ diff --git a/progress/github/SpecForge/tests/test_data/test_references/deepseek-v32_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/deepseek-v32_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..3aadda5a94d02c25626d394173a7d587400c57b7 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/deepseek-v32_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [0, 128803, 18387, 477, 440, 33, 128804, 6759, 2329, 344, 1646, 19566, 20, 16, 1, 128803, 4117, 3072, 477, 440, 33, 128804, 779, 1737, 3072, 16, 1], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/deepseek-v3_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/deepseek-v3_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..1d759a83637eba8dab677fbe4773ed2277519076 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/deepseek-v3_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [0, 3476, 477, 260, 11502, 22896, 16, 128803, 18387, 477, 440, 33, 128804, 6759, 2329, 344, 1646, 19566, 20, 16, 1, 128803, 4117, 3072, 477, 440, 33, 128804, 779, 1737, 3072, 16, 1], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/gpt-oss_gpt-oss_ref.json b/progress/github/SpecForge/tests/test_data/test_references/gpt-oss_gpt-oss_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..88344bf001ebb7891ab9b437d4683fd913cadfbf --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/gpt-oss_gpt-oss_ref.json @@ -0,0 +1 @@ +{"input_ids": [200006, 17360, 200008, 3575, 553, 17554, 162016, 11, 261, 4410, 6439, 2359, 22203, 656, 7788, 17527, 558, 87447, 100594, 25, 220, 1323, 19, 12, 3218, 198, 6576, 3521, 25, 220, 1323, 20, 12, 3218, 12, 2029, 279, 30377, 289, 25, 4465, 279, 2, 13888, 18403, 25, 8450, 11, 49159, 11, 1721, 13, 21030, 2804, 413, 7360, 395, 1753, 3176, 13, 200007, 200006, 1428, 200008, 176289, 90765, 48711, 13, 200007, 200006, 173781, 200005, 35644, 200008, 976, 1825, 10648, 261, 18522, 328, 48889, 35438, 13, 357, 1757, 4321, 20485, 3161, 12608, 25399, 536, 326, 44942, 30540, 13, 200007, 200006, 173781, 200005, 17196, 200008, 170939, 35438, 382, 290, 5012, 328, 7165, 326, 5954, 540, 290, 1645, 18864, 3211, 1008, 200007, 200006, 1428, 200008, 176289, 90765, 48711, 13, 200007, 200006, 173781, 200005, 17196, 200008, 15390, 1486, 11027, 200007], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-instruct_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-instruct_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..932853d41e351337e961aff93256580cc95ae0a9 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-instruct_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [163594, 14062, 163601, 3900, 554, 261, 13205, 26626, 13, 163586, 163587, 2482, 163601, 24328, 554, 398, 30, 163586, 163588, 69702, 163601, 6725, 1530, 387, 1999, 33249, 17, 13, 163586, 163587, 2482, 163601, 6034, 3410, 554, 398, 30, 163586, 163588, 69702, 163601, 1228, 2285, 3410, 13, 163586], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-thinking_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-thinking_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..dc35abd43f6cc3f265b3b8f1d82fc5eb506565c7 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/kimi-k2-thinking_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [163594, 14062, 163601, 3900, 554, 261, 13205, 26626, 13, 163586, 163587, 2482, 163601, 24328, 554, 398, 30, 163586, 163588, 69702, 163601, 163606, 163607, 6725, 1530, 387, 1999, 33249, 17, 13, 163586, 163587, 2482, 163601, 6034, 3410, 554, 398, 30, 163586, 163588, 69702, 163601, 1228, 2285, 3410, 13, 163586], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/ling-flash-2.0_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/ling-flash-2.0_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..b0b4a4e1988ec0666795f1cc113436bb97fb019e --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/ling-flash-2.0_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [157151, 90827, 157152, 2496, 449, 259, 9031, 16841, 13, 198, 14136, 5381, 6350, 928, 156895, 157151, 39, 116171, 157152, 13328, 449, 362, 30, 156895, 157151, 8469, 7342, 5468, 157152, 4653, 1717, 341, 1834, 36364, 17, 13, 156895, 157151, 39, 116171, 157152, 3115, 2622, 449, 362, 30, 156895, 157151, 8469, 7342, 5468, 157152, 16, 16, 1594, 2622, 13, 156895], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..2d8de2ea90f5ada05e36e9fbc71304ebab64e6e0 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 15191, 525, 498, 30, 151645, 198, 151644, 77091, 198, 5050, 829, 374, 1207, 16948, 17, 13, 151645, 198, 151644, 872, 198, 4340, 2310, 525, 498, 30, 151645, 198, 151644, 77091, 198, 151667, 271, 151668, 271, 16, 16, 1635, 2310, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_tool-use_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_tool-use_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..6272088047096d71cc59071caf691b448c613e7e --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen3-instruct_tool-use_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 3838, 594, 279, 9104, 1075, 304, 26549, 3351, 30, 151645, 198, 151644, 77091, 198, 40, 3278, 1779, 279, 1482, 9104, 304, 26549, 369, 498, 13, 151645, 198, 151644, 872, 198, 151665, 198, 4913, 2527, 788, 330, 3430, 23649, 497, 330, 34558, 788, 220, 17, 17, 11, 330, 9056, 788, 330, 50, 27297, 16707, 151666, 151645, 198, 151644, 77091, 198, 785, 1482, 9104, 304, 26549, 374, 39698, 448, 264, 9315, 315, 220, 17, 17, 30937, 13, 151645, 198, 151644, 872, 198, 151665, 198, 4913, 3843, 788, 330, 34, 40247, 497, 330, 58984, 788, 330, 14008, 49293, 678, 1899, 1189, 532, 151666, 198, 151665, 198, 4913, 3843, 788, 330, 34, 40247, 497, 330, 58984, 788, 330, 14008, 49293, 678, 1899, 1189, 532, 151666, 151645, 198, 151644, 872, 198, 21396, 0, 2980, 498, 1083, 3291, 752, 421, 432, 686, 11174, 16577, 30, 151645, 198, 151644, 77091, 198, 151667, 271, 151668, 271, 28715, 389, 279, 17595, 11, 1052, 686, 387, 902, 11174, 16577, 2293, 17119, 2797, 49293, 678, 1899, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen3-next-thinking_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen3-next-thinking_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..fe12355606c8c63f489913753e215863f4fba1ce --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen3-next-thinking_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 15191, 525, 498, 30, 151645, 198, 151644, 77091, 198, 5050, 829, 374, 1207, 16948, 17, 13, 151645, 198, 151644, 872, 198, 4340, 2310, 525, 498, 30, 151645, 198, 151644, 77091, 198, 151667, 198, 16, 16, 1635, 2310, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen3-thinking_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen3-thinking_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..61885ced3006f2cdc6334656511caf373361bb0a --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen3-thinking_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 15191, 525, 498, 30, 151645, 198, 151644, 77091, 198, 5050, 829, 374, 1207, 16948, 17, 13, 151645, 198, 151644, 872, 198, 4340, 2310, 525, 498, 30, 151645, 198, 151644, 77091, 198, 16, 16, 1635, 2310, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..61885ced3006f2cdc6334656511caf373361bb0a --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 15191, 525, 498, 30, 151645, 198, 151644, 77091, 198, 5050, 829, 374, 1207, 16948, 17, 13, 151645, 198, 151644, 872, 198, 4340, 2310, 525, 498, 30, 151645, 198, 151644, 77091, 198, 16, 16, 1635, 2310, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/qwen_tool-use_ref.json b/progress/github/SpecForge/tests/test_data/test_references/qwen_tool-use_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..b9b96e47476f270d9bbcfc42f93b40b5ec46aebe --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/qwen_tool-use_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 429, 646, 2711, 279, 3482, 323, 1779, 279, 9104, 13, 151645, 198, 151644, 872, 198, 3838, 594, 279, 9104, 1075, 304, 26549, 3351, 30, 151645, 198, 151644, 77091, 198, 40, 3278, 1779, 279, 1482, 9104, 304, 26549, 369, 498, 13, 151645, 198, 151644, 872, 198, 151665, 198, 4913, 2527, 788, 330, 3430, 23649, 497, 330, 34558, 788, 220, 17, 17, 11, 330, 9056, 788, 330, 50, 27297, 16707, 151666, 151645, 198, 151644, 77091, 198, 785, 1482, 9104, 304, 26549, 374, 39698, 448, 264, 9315, 315, 220, 17, 17, 30937, 13, 151645, 198, 151644, 872, 198, 151665, 198, 4913, 3843, 788, 330, 34, 40247, 497, 330, 58984, 788, 330, 14008, 49293, 678, 1899, 1189, 532, 151666, 198, 151665, 198, 4913, 3843, 788, 330, 34, 40247, 497, 330, 58984, 788, 330, 14008, 49293, 678, 1899, 1189, 532, 151666, 151645, 198, 151644, 872, 198, 21396, 0, 2980, 498, 1083, 3291, 752, 421, 432, 686, 11174, 16577, 30, 151645, 198, 151644, 77091, 198, 151667, 271, 151668, 271, 28715, 389, 279, 17595, 11, 1052, 686, 387, 902, 11174, 16577, 2293, 17119, 2797, 49293, 678, 1899, 13, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_data/test_references/repo-wiki_standard_ref.json b/progress/github/SpecForge/tests/test_data/test_references/repo-wiki_standard_ref.json new file mode 100644 index 0000000000000000000000000000000000000000..a12f427fe74fbd1f9afb08e3b9bcbc7c7feeed85 --- /dev/null +++ b/progress/github/SpecForge/tests/test_data/test_references/repo-wiki_standard_ref.json @@ -0,0 +1 @@ +{"input_ids": [151644, 872, 271, 334, 98743, 25, 1446, 27732, 5889, 304, 8453, 7, 104811, 8, 659, 56177, 27, 5778, 397, 2610, 525, 458, 20685, 10916, 9705, 11470, 448, 5538, 18726, 304, 3162, 4401, 11, 1849, 2884, 11, 323, 15754, 3139, 25262, 13, 4615, 35874, 15448, 304, 41018, 6351, 2038, 78267, 323, 6825, 9705, 14389, 429, 5165, 36967, 10916, 6540, 1119, 41679, 11, 91078, 6832, 42914, 13, 1446, 3535, 429, 24364, 9705, 17045, 438, 279, 14164, 1948, 2038, 23094, 323, 15754, 25148, 624, 522, 5778, 1339, 27, 8202, 8467, 397, 7771, 8954, 374, 311, 23643, 279, 3897, 12542, 323, 6923, 264, 15817, 9705, 6220, 5944, 429, 17045, 438, 279, 16266, 369, 264, 1879, 14800, 9705, 3910, 13, 1096, 5944, 1969, 27968, 311, 13402, 3941, 678, 3139, 5866, 11, 504, 67458, 10887, 3974, 389, 37569, 311, 11647, 22703, 11682, 5785, 7236, 382, 334, 10234, 419, 12850, 95518, 8325, 12, 51143, 9705, 27957, 25271, 15754, 389, 37569, 882, 11, 42054, 1824, 22305, 11, 323, 14177, 973, 4565, 24376, 13, 4615, 6358, 686, 8253, 1246, 13444, 7263, 646, 3535, 11, 4211, 11, 323, 13036, 419, 2038, 3152, 624, 522, 8202, 8467, 1339, 5338, 11, 1401, 1119, 279, 2701, 1995, 911, 279, 12542, 498, 1184, 311, 975, 389, 1447, 4624, 1034, 6220, 4916, 510, 27, 4987, 38283, 397, 624, 144663, 16991, 8315, 198, 262, 80493, 8315, 2972, 198, 286, 22612, 242, 16991, 2943, 18002, 198, 262, 80493, 8315, 4030, 198, 286, 80493, 3538, 18002, 198, 286, 22612, 242, 16991, 3538, 4452, 18002, 198, 262, 80493, 5439, 198, 286, 80493, 5439, 18002, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 262, 22612, 242, 16991, 1887, 18002, 198, 144663, 16991, 1936, 21492, 198, 262, 80493, 5439, 198, 286, 80493, 5439, 18002, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 262, 80493, 4772, 2972, 198, 286, 80493, 2943, 18002, 198, 286, 80493, 9109, 18002, 198, 286, 22612, 242, 16991, 7497, 18002, 198, 262, 80493, 4772, 6507, 198, 286, 22612, 242, 16991, 4119, 18002, 198, 262, 80493, 4772, 4030, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 3538, 18002, 198, 286, 22612, 242, 16991, 3538, 4452, 18002, 198, 262, 80493, 4772, 4314, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 3553, 18002, 198, 286, 22612, 242, 16991, 3553, 4452, 18002, 198, 262, 80493, 4772, 1313, 198, 286, 80493, 1638, 91578, 18002, 198, 286, 80493, 26588, 91578, 18002, 198, 286, 80493, 2415, 18002, 198, 286, 22612, 242, 16991, 2415, 4452, 18002, 198, 262, 22612, 242, 16991, 1887, 18002, 198, 144663, 16991, 2193, 198, 262, 80493, 8315, 198, 286, 22612, 242, 16991, 2331, 33406, 198, 262, 80493, 1936, 21492, 198, 286, 22612, 242, 16991, 2331, 33406, 198, 262, 80493, 6238, 198, 286, 22612, 242, 16991, 2331, 33406, 198, 262, 80493, 13291, 198, 286, 22612, 242, 16991, 2331, 33406, 198, 262, 22612, 242, 16991, 28331, 198, 286, 22612, 242, 16991, 2331, 33406, 198, 144663, 16991, 6200, 198, 262, 80493, 23404, 2733, 18002, 198, 262, 80493, 20882, 18002, 198, 262, 80493, 20882, 4452, 18002, 198, 262, 80493, 4078, 5191, 18002, 198, 262, 80493, 4078, 5191, 4452, 18002, 198, 262, 80493, 37664, 18002, 198, 262, 80493, 3546, 8296, 18002, 198, 262, 80493, 3546, 8296, 4452, 18002, 198, 262, 80493, 2270, 466, 824, 18002, 198, 262, 80493, 2270, 466, 824, 4452, 18002, 198, 262, 80493, 14397, 8467, 18002, 198, 262, 80493, 14397, 8467, 4452, 18002, 198, 262, 80493, 14397, 842, 18002, 198, 262, 80493, 14397, 842, 4452, 18002, 198, 262, 80493, 14397, 3109, 18002, 198, 262, 80493, 14397, 3109, 4452, 18002, 198, 262, 22612, 242, 16991, 6573, 8950, 18002, 198, 144663, 16991, 26588, 198, 262, 80493, 8315, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 1936, 21492, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 58113, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 6238, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 13291, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 1273, 3848, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 80493, 28331, 198, 286, 22612, 242, 16991, 40549, 1192, 198, 262, 22612, 242, 16991, 6505, 62, 73561, 2395, 198, 144663, 16991, 10295, 198, 262, 80493, 3483, 18855, 198, 286, 80493, 2193, 198, 310, 80493, 8315, 198, 394, 22612, 242, 16991, 4401, 33406, 198, 310, 80493, 1936, 21492, 198, 394, 22612, 242, 16991, 4401, 33406, 198, 310, 80493, 6238, 198, 394, 22612, 242, 16991, 4401, 33406, 198, 310, 80493, 13291, 198, 394, 22612, 242, 16991, 4401, 33406, 198, 310, 22612, 242, 16991, 28331, 198, 394, 22612, 242, 16991, 4401, 33406, 198, 286, 80493, 61945, 21324, 198, 286, 80493, 8315, 11667, 4090, 2395, 198, 286, 80493, 8315, 11667, 4906, 15847, 2395, 198, 286, 80493, 8315, 23241, 4090, 2395, 198, 286, 80493, 8315, 23241, 4906, 15847, 2395, 198, 286, 80493, 58113, 4090, 2395, 198, 286, 80493, 58113, 4906, 15847, 2395, 198, 286, 22612, 242, 16991, 58113, 4906, 80143, 2395, 198, 262, 22612, 242, 16991, 595, 23, 82, 198, 286, 80493, 61945, 21324, 198, 286, 22612, 242, 16991, 16661, 4323, 198, 144663, 16991, 33765, 198, 262, 80493, 2193, 198, 286, 80493, 8315, 33406, 198, 286, 80493, 1936, 21492, 33406, 198, 286, 80493, 6238, 33406, 198, 286, 80493, 13291, 33406, 198, 286, 22612, 242, 16991, 28331, 33406, 198, 262, 80493, 19911, 198, 286, 80493, 13009, 33406, 198, 286, 80493, 1936, 21492, 33406, 198, 286, 80493, 2193, 33406, 198, 286, 80493, 32372, 33406, 198, 286, 80493, 13291, 33406, 198, 286, 80493, 1273, 3848, 33406, 198, 286, 22612, 242, 16991, 90982, 33406, 198, 262, 80493, 21266, 33406, 198, 262, 22612, 242, 16991, 2750, 33406, 198, 144663, 16991, 3051, 198, 262, 80493, 19163, 198, 286, 80493, 19163, 7650, 198, 310, 22612, 242, 16991, 5975, 18002, 198, 286, 80493, 342, 4837, 20942, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 22612, 242, 16991, 342, 4837, 18002, 198, 286, 80493, 305, 34378, 20942, 198, 310, 80493, 3482, 71, 34378, 198, 394, 80493, 2943, 18002, 198, 394, 80493, 2943, 4452, 18002, 198, 394, 80493, 2193, 18002, 198, 394, 22612, 242, 16991, 2951, 18002, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 22612, 242, 16991, 2193, 18002, 198, 286, 80493, 1758, 20942, 198, 310, 80493, 1758, 18002, 198, 310, 22612, 242, 16991, 1758, 4452, 18002, 198, 286, 80493, 829, 2343, 198, 310, 80493, 1815, 261, 18002, 198, 310, 22612, 242, 16991, 1815, 261, 4452, 18002, 198, 286, 80493, 19424, 20942, 198, 310, 80493, 4763, 198, 394, 22612, 242, 16991, 4763, 18002, 198, 310, 80493, 23404, 2972, 18002, 198, 310, 80493, 23404, 2972, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 80493, 2193, 4452, 18002, 198, 310, 80493, 4772, 2972, 18002, 198, 310, 22612, 242, 16991, 4772, 2972, 4452, 18002, 198, 286, 80493, 274, 18, 20942, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 22612, 242, 16991, 274, 18, 18002, 198, 286, 80493, 12455, 20942, 198, 310, 80493, 61945, 21324, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 22612, 242, 16991, 2193, 18002, 198, 286, 80493, 5704, 20942, 198, 310, 80493, 28431, 198, 394, 22612, 242, 16991, 2943, 4452, 18002, 198, 310, 80493, 61945, 21324, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 22612, 242, 16991, 10802, 18002, 198, 286, 80493, 1273, 3848, 198, 310, 80493, 2943, 18002, 198, 310, 80493, 2943, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 80493, 3538, 18002, 198, 310, 22612, 242, 16991, 3538, 4452, 18002, 198, 286, 80493, 2943, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 18021, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 6645, 18002, 198, 286, 80493, 6645, 4452, 18002, 198, 286, 80493, 60829, 18002, 198, 286, 80493, 2606, 18002, 198, 286, 80493, 3059, 18002, 198, 286, 22612, 242, 16991, 42166, 18002, 198, 262, 80493, 23404, 17168, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 2053, 85424, 18002, 198, 286, 22612, 242, 16991, 2053, 85424, 4452, 18002, 198, 262, 80493, 6644, 615, 4466, 198, 286, 80493, 5476, 67, 198, 310, 80493, 2943, 18002, 198, 310, 22612, 242, 16991, 2193, 18002, 198, 286, 80493, 26588, 75841, 198, 310, 80493, 21348, 18002, 198, 310, 22612, 242, 16991, 2193, 18002, 198, 286, 22612, 242, 16991, 8633, 18002, 198, 262, 80493, 26588, 29172, 198, 286, 80493, 8317, 198, 310, 80493, 5975, 18002, 198, 310, 80493, 926, 35403, 261, 18002, 198, 310, 80493, 926, 35403, 261, 4452, 18002, 198, 310, 80493, 25991, 35403, 261, 18002, 198, 310, 80493, 25991, 35403, 261, 4452, 18002, 198, 310, 80493, 7497, 18002, 198, 310, 22612, 242, 16991, 8317, 261, 18002, 198, 286, 80493, 76899, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 83232, 18002, 198, 286, 80493, 83232, 84245, 4452, 18002, 198, 286, 80493, 11160, 18002, 198, 286, 80493, 11160, 4452, 18002, 198, 286, 80493, 12716, 18002, 198, 286, 80493, 12716, 4452, 18002, 198, 286, 80493, 5819, 20602, 18002, 198, 286, 80493, 5819, 20602, 4452, 18002, 198, 286, 80493, 1273, 6031, 4452, 18002, 198, 286, 22612, 242, 16991, 66563, 18002, 198, 262, 80493, 702, 4079, 287, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 27879, 34683, 18002, 198, 286, 80493, 27879, 34683, 4452, 18002, 198, 286, 80493, 10058, 18002, 198, 286, 80493, 10058, 4452, 18002, 198, 286, 22612, 242, 16991, 7497, 18002, 198, 262, 80493, 2820, 2028, 198, 286, 80493, 40915, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 4051, 18002, 198, 286, 80493, 4051, 4452, 18002, 198, 286, 80493, 1140, 18002, 198, 286, 80493, 8718, 18002, 198, 286, 80493, 8718, 4452, 18002, 198, 286, 80493, 27879, 18002, 198, 286, 80493, 27879, 8727, 18002, 198, 286, 80493, 27879, 8727, 4452, 18002, 198, 286, 80493, 27879, 4452, 18002, 198, 286, 80493, 1584, 18002, 198, 286, 80493, 1584, 4452, 18002, 198, 286, 22612, 242, 16991, 7497, 18002, 198, 262, 80493, 3468, 1607, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 1140, 18002, 198, 286, 22612, 242, 16991, 1140, 4452, 18002, 198, 262, 80493, 18026, 86, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 65551, 57445, 18002, 198, 286, 80493, 65551, 57445, 4452, 18002, 198, 286, 22612, 242, 16991, 1273, 6031, 4452, 18002, 198, 262, 80493, 2270, 466, 69, 11706, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 2193, 4452, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 13823, 18002, 198, 286, 22612, 242, 16991, 13823, 4452, 18002, 198, 262, 80493, 29679, 198, 286, 80493, 29679, 18002, 198, 286, 22612, 242, 16991, 29679, 4452, 18002, 198, 262, 80493, 52995, 44848, 198, 286, 80493, 4772, 9995, 1693, 198, 310, 80493, 31558, 18002, 198, 310, 80493, 31558, 4452, 18002, 198, 310, 80493, 37664, 18002, 198, 310, 80493, 1299, 6295, 18002, 198, 310, 80493, 1299, 6295, 4452, 18002, 198, 310, 80493, 3553, 18002, 198, 310, 80493, 3553, 4452, 18002, 198, 310, 80493, 3383, 18002, 198, 310, 22612, 242, 16991, 7497, 18002, 198, 286, 80493, 3270, 1419, 198, 310, 80493, 31558, 18002, 198, 310, 80493, 31558, 4452, 18002, 198, 310, 80493, 37664, 18002, 198, 310, 80493, 3239, 18002, 198, 310, 80493, 3553, 18002, 198, 310, 80493, 3553, 4452, 18002, 198, 310, 80493, 3383, 18002, 198, 310, 22612, 242, 16991, 7497, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 5975, 18002, 198, 286, 80493, 24099, 18002, 198, 286, 80493, 6645, 18002, 198, 286, 22612, 242, 16991, 6645, 4452, 18002, 198, 262, 80493, 3553, 198, 286, 80493, 2331, 198, 310, 80493, 733, 18002, 198, 310, 80493, 5975, 18002, 198, 310, 80493, 1034, 9078, 18002, 198, 310, 80493, 1034, 9078, 4452, 18002, 198, 310, 80493, 1034, 5376, 18002, 198, 310, 80493, 1034, 5376, 4452, 18002, 198, 310, 80493, 1034, 10287, 18002, 198, 310, 80493, 1034, 10287, 4452, 18002, 198, 310, 80493, 1034, 6443, 18189, 18002, 198, 310, 80493, 1034, 14809, 18002, 198, 310, 80493, 37664, 18002, 198, 310, 22612, 242, 16991, 1273, 6031, 4452, 18002, 198, 286, 80493, 11160, 198, 310, 80493, 1537, 12759, 3009, 18002, 198, 310, 80493, 1537, 12759, 3009, 4452, 18002, 198, 310, 80493, 11160, 18002, 198, 310, 80493, 22334, 18002, 198, 310, 80493, 22334, 4452, 18002, 198, 310, 80493, 30575, 5490, 18002, 198, 310, 22612, 242, 16991, 30575, 5490, 4452, 18002, 198, 286, 80493, 2162, 35939, 14809, 18002, 198, 286, 80493, 2162, 35939, 14809, 4452, 18002, 198, 286, 80493, 2162, 14809, 18002, 198, 286, 80493, 2162, 14809, 4452, 18002, 198, 286, 80493, 6500, 14809, 18002, 198, 286, 80493, 21290, 18002, 198, 286, 80493, 21290, 4452, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 1034, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 4285, 14809, 18002, 198, 286, 80493, 4285, 14809, 4452, 18002, 198, 286, 80493, 7497, 18002, 198, 286, 80493, 8135, 14809, 18002, 198, 286, 22612, 242, 16991, 12439, 18002, 198, 262, 80493, 30575, 198, 286, 80493, 4179, 49710, 440, 684, 198, 310, 80493, 2193, 18002, 198, 310, 80493, 4357, 18002, 198, 310, 80493, 16851, 18002, 198, 310, 80493, 16851, 4452, 18002, 198, 310, 80493, 7497, 18002, 198, 310, 22612, 242, 16991, 4094, 18002, 198, 286, 80493, 28809, 198, 310, 80493, 21483, 4584, 198, 394, 80493, 7177, 18002, 198, 394, 22612, 242, 16991, 7177, 4452, 18002, 198, 310, 80493, 5353, 261, 198, 394, 80493, 5353, 261, 18002, 198, 394, 22612, 242, 16991, 5353, 261, 4452, 18002, 198, 310, 80493, 4534, 198, 394, 80493, 2193, 18002, 198, 394, 80493, 4534, 18002, 198, 394, 80493, 4534, 4452, 18002, 198, 394, 80493, 12418, 45159, 18002, 198, 394, 80493, 12418, 45159, 4452, 18002, 198, 394, 80493, 37664, 18002, 198, 394, 80493, 1424, 927, 4407, 18002, 198, 394, 80493, 1424, 927, 4407, 4452, 18002, 198, 394, 22612, 242, 16991, 1943, 18002, 198, 310, 80493, 4534, 2454, 198, 394, 80493, 2193, 18002, 198, 394, 80493, 1584, 18002, 198, 394, 22612, 242, 16991, 1584, 4452, 18002, 198, 310, 80493, 6845, 198, 394, 80493, 4349, 66, 485, 719, 198, 503, 80493, 1638, 22773, 18002, 198, 503, 80493, 6645, 18002, 198, 503, 80493, 6645, 4452, 18002, 198, 503, 80493, 4842, 18002, 198, 503, 22612, 242, 16991, 8848, 267, 12978, 22773, 18002, 198, 394, 80493, 2193, 18002, 198, 394, 80493, 38799, 18002, 198, 394, 80493, 38799, 4452, 18002, 198, 394, 80493, 14397, 18002, 198, 394, 80493, 12811, 13996, 2566, 18002, 198, 394, 80493, 12811, 13996, 2566, 4452, 18002, 198, 394, 22612, 242, 16991, 30575, 12759, 1670, 28058, 18002, 198, 310, 80493, 30575, 839, 198, 394, 22612, 242, 16991, 5925, 18002, 198, 310, 80493, 42112, 18002, 198, 310, 80493, 42112, 4452, 18002, 198, 310, 80493, 2193, 18002, 198, 310, 80493, 54717, 18002, 198, 310, 80493, 4357, 18002, 198, 310, 80493, 4357, 4452, 18002, 198, 310, 80493, 18646, 18002, 198, 310, 80493, 28809, 18002, 198, 310, 80493, 28809, 4452, 18002, 198, 310, 80493, 1584, 18002, 198, 310, 22612, 242, 16991, 1273, 6031, 4452, 18002, 198, 286, 22612, 242, 16991, 5819, 198, 310, 80493, 8315, 16172, 198, 394, 80493, 37664, 18002, 198, 394, 80493, 9666, 18002, 198, 394, 80493, 30575, 18002, 198, 394, 80493, 30575, 42873, 18002, 198, 394, 80493, 30575, 42873, 4452, 18002, 198, 394, 22612, 242, 16991, 30575, 4452, 18002, 198, 310, 80493, 6238, 16172, 198, 394, 80493, 30575, 18002, 198, 394, 80493, 30575, 42873, 18002, 198, 394, 80493, 30575, 42873, 4452, 18002, 198, 394, 22612, 242, 16991, 30575, 4452, 18002, 198, 310, 80493, 4349, 66, 485, 998, 198, 394, 80493, 4147, 18002, 198, 394, 22612, 242, 16991, 1034, 18002, 198, 310, 80493, 37664, 18002, 198, 310, 80493, 5819, 18002, 198, 310, 80493, 30575, 3109, 18002, 198, 310, 22612, 242, 16991, 30575, 3109, 4452, 18002, 198, 262, 22612, 242, 16991, 41730, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 144663, 16991, 2205, 1999, 198, 262, 80493, 17063, 198, 286, 80493, 220, 15, 15, 15, 15, 16, 9372, 9995, 1693, 6137, 18002, 198, 286, 22612, 242, 16991, 220, 15, 15, 15, 15, 17, 9165, 1419, 6137, 18002, 198, 262, 80493, 2193, 18002, 198, 262, 80493, 4625, 18002, 198, 262, 22612, 242, 16991, 37664, 18002, 198, 144663, 16991, 16734, 198, 262, 80493, 2193, 18002, 198, 262, 80493, 8386, 18002, 198, 262, 80493, 296, 18, 18002, 198, 262, 80493, 16734, 18002, 198, 262, 22612, 242, 16991, 10472, 67, 18002, 198, 144663, 16991, 68909, 198, 262, 80493, 8315, 14, 8092, 2972, 198, 286, 22612, 242, 16991, 2943, 18002, 198, 262, 80493, 1936, 21492, 198, 286, 80493, 4772, 2972, 198, 310, 80493, 2943, 18002, 198, 310, 22612, 242, 16991, 9109, 18002, 198, 286, 80493, 4772, 4314, 198, 310, 80493, 1034, 4314, 18002, 198, 310, 22612, 242, 16991, 3553, 18002, 198, 286, 22612, 242, 16991, 4772, 1313, 198, 310, 22612, 242, 16991, 24036, 48943, 18002, 198, 262, 80493, 3051, 198, 286, 80493, 19163, 198, 310, 80493, 342, 4837, 20942, 198, 394, 22612, 242, 16991, 342, 4837, 18002, 198, 310, 80493, 305, 34378, 20942, 21808, 71, 34378, 198, 394, 22612, 242, 16991, 2943, 18002, 198, 310, 80493, 274, 18, 20942, 198, 394, 22612, 242, 16991, 274, 18, 18002, 198, 310, 22612, 242, 16991, 2943, 18002, 198, 286, 80493, 6644, 615, 4466, 198, 310, 80493, 5476, 67, 198, 394, 22612, 242, 16991, 2943, 18002, 198, 310, 80493, 26588, 75841, 198, 394, 22612, 242, 16991, 26588, 2972, 18002, 198, 310, 22612, 242, 16991, 8633, 18002, 198, 286, 80493, 26588, 29172, 14, 24188, 198, 310, 22612, 242, 16991, 6532, 17366, 596, 802, 261, 18002, 198, 286, 80493, 702, 4079, 287, 198, 310, 80493, 10058, 18002, 198, 310, 22612, 242, 16991, 55727, 18002, 198, 286, 80493, 2820, 2028, 198, 310, 80493, 40915, 18002, 198, 310, 80493, 4051, 18002, 198, 310, 22612, 242, 16991, 27879, 5315, 18002, 198, 286, 80493, 3468, 1607, 198, 310, 22612, 242, 16991, 1140, 18002, 198, 286, 80493, 52995, 44848, 198, 310, 80493, 4772, 9995, 1693, 198, 394, 22612, 242, 16991, 8699, 16112, 18002, 198, 310, 80493, 31558, 18002, 198, 310, 80493, 6645, 18002, 198, 310, 80493, 3553, 18002, 198, 310, 22612, 242, 16991, 3383, 18002, 198, 286, 80493, 3553, 198, 310, 22612, 242, 16991, 1461, 485, 329, 18189, 18002, 198, 286, 22612, 242, 16991, 30575, 2687, 15222, 198, 310, 80493, 18646, 4788, 15222, 18002, 198, 310, 22612, 242, 16991, 28809, 18002, 198, 262, 80493, 6238, 34827, 2972, 198, 286, 80493, 2943, 18002, 198, 286, 80493, 2943, 48943, 18002, 198, 286, 80493, 10652, 2972, 18002, 198, 286, 80493, 10652, 19979, 18002, 198, 286, 22612, 242, 16991, 9109, 18002, 198, 262, 80493, 28331, 198, 286, 80493, 21483, 2972, 198, 310, 22612, 242, 16991, 2943, 18002, 198, 286, 80493, 2270, 466, 69, 509, 1451, 198, 310, 22612, 242, 16991, 2943, 18002, 198, 286, 80493, 6238, 4314, 198, 310, 22612, 242, 16991, 3553, 18002, 198, 286, 22612, 242, 16991, 14397, 4314, 198, 310, 22612, 242, 16991, 3553, 18002, 198, 262, 22612, 242, 16991, 12439, 198, 286, 80493, 7681, 454, 198, 310, 80493, 9873, 8202, 18002, 198, 310, 22612, 242, 16991, 3383, 41736, 18002, 198, 286, 22612, 242, 16991, 54320, 628, 321, 198, 310, 22612, 242, 16991, 4778, 32981, 712, 18002, 198, 144663, 16991, 70482, 198, 262, 80493, 2193, 198, 286, 80493, 8315, 18002, 198, 286, 80493, 2331, 18002, 198, 286, 80493, 1936, 21492, 18002, 198, 286, 80493, 1638, 18002, 198, 286, 80493, 6238, 18002, 198, 286, 80493, 13291, 18002, 198, 286, 22612, 242, 16991, 28331, 18002, 198, 262, 22612, 242, 16991, 70482, 18002, 198, 144663, 16991, 6238, 198, 262, 80493, 23404, 2972, 198, 286, 80493, 2943, 18002, 198, 286, 80493, 10652, 8179, 18002, 198, 286, 80493, 5975, 18002, 198, 286, 80493, 9109, 18002, 198, 286, 22612, 242, 16991, 82257, 18002, 198, 262, 80493, 23404, 4030, 198, 286, 80493, 10652, 8179, 4452, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 3538, 18002, 198, 286, 80493, 3538, 4452, 18002, 198, 286, 80493, 1273, 6031, 4452, 18002, 198, 286, 80493, 82257, 18002, 198, 286, 80493, 12439, 18002, 198, 286, 22612, 242, 16991, 12439, 4452, 18002, 198, 262, 80493, 5439, 198, 286, 80493, 5439, 18002, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 262, 22612, 242, 16991, 1887, 18002, 198, 144663, 16991, 18433, 4322, 17, 79, 198, 262, 22612, 242, 16991, 281, 17, 79, 57322, 198, 144663, 16991, 13291, 198, 262, 80493, 5439, 198, 286, 80493, 5439, 18002, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 262, 80493, 13291, 4030, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 90477, 18002, 198, 286, 80493, 855, 19963, 18002, 198, 286, 80493, 19424, 19691, 18002, 198, 286, 80493, 3538, 18002, 198, 286, 80493, 3538, 4452, 18002, 198, 286, 22612, 242, 16991, 1273, 6031, 4452, 18002, 198, 262, 80493, 19424, 9199, 198, 286, 80493, 2193, 18002, 198, 286, 22612, 242, 16991, 3538, 18002, 198, 262, 22612, 242, 16991, 1887, 18002, 198, 144663, 16991, 19502, 198, 262, 22612, 242, 16991, 23789, 14120, 2395, 198, 144663, 16991, 1273, 198, 262, 80493, 10135, 198, 286, 80493, 1304, 2327, 18725, 3288, 198, 286, 80493, 6813, 7197, 198, 286, 80493, 390, 723, 477, 7197, 198, 286, 80493, 1273, 15467, 7197, 198, 286, 80493, 1273, 814, 13659, 7197, 198, 286, 80493, 1273, 25533, 1693, 7197, 198, 286, 80493, 82257, 7197, 198, 286, 22612, 242, 16991, 12439, 7197, 198, 262, 22612, 242, 16991, 55026, 198, 286, 80493, 61945, 21324, 198, 286, 22612, 242, 16991, 79451, 666, 15546, 2395, 198, 144663, 16991, 7375, 198, 262, 80493, 9544, 198, 286, 80493, 6815, 261, 198, 310, 80493, 1887, 18002, 198, 310, 80493, 11540, 18002, 198, 310, 22612, 242, 16991, 6815, 18002, 198, 286, 80493, 18646, 198, 310, 22612, 242, 16991, 1887, 18002, 198, 286, 80493, 19038, 198, 310, 80493, 61681, 67313, 14738, 7197, 198, 310, 22612, 242, 16991, 4194, 49443, 14738, 7197, 198, 286, 80493, 1273, 3848, 198, 310, 22612, 242, 16991, 1887, 18002, 198, 286, 22612, 242, 16991, 41048, 198, 310, 80493, 1099, 198, 394, 80493, 15877, 198, 503, 22612, 242, 16991, 906, 4327, 198, 394, 80493, 5272, 198, 503, 22612, 242, 16991, 906, 2564, 198, 394, 22612, 242, 16991, 6994, 198, 503, 22612, 242, 16991, 906, 2857, 198, 310, 80493, 1887, 18002, 198, 310, 22612, 242, 16991, 3538, 18002, 198, 262, 22612, 242, 16991, 3051, 198, 286, 80493, 2168, 198, 310, 22612, 242, 16991, 2168, 18002, 198, 286, 22612, 242, 16991, 55026, 1314, 18002, 198, 144663, 16991, 28331, 198, 262, 80493, 21483, 2972, 198, 286, 22612, 242, 16991, 2943, 18002, 198, 262, 80493, 5439, 198, 286, 80493, 5439, 18002, 198, 286, 22612, 242, 16991, 2193, 18002, 198, 262, 80493, 2270, 466, 69, 509, 1451, 198, 286, 80493, 2943, 18002, 198, 286, 22612, 242, 16991, 7497, 18002, 198, 262, 80493, 6238, 4314, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 3553, 18002, 198, 286, 22612, 242, 16991, 3553, 4452, 18002, 198, 262, 80493, 14397, 10661, 411, 34790, 198, 286, 80493, 79314, 22773, 18002, 198, 286, 80493, 79314, 22773, 4452, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 1638, 22773, 18002, 198, 286, 80493, 1638, 22773, 4452, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 14397, 10661, 411, 34790, 18002, 198, 286, 22612, 242, 16991, 14397, 10661, 411, 34790, 4452, 18002, 198, 262, 80493, 14397, 4314, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 2205, 18002, 198, 286, 80493, 2205, 4452, 18002, 198, 286, 80493, 20870, 18002, 198, 286, 80493, 20870, 4452, 18002, 198, 286, 80493, 3553, 18002, 198, 286, 22612, 242, 16991, 7497, 18002, 198, 262, 80493, 90982, 2836, 198, 286, 80493, 21483, 18002, 198, 286, 80493, 21483, 4452, 18002, 198, 286, 80493, 2193, 18002, 198, 286, 80493, 37664, 18002, 198, 286, 80493, 2270, 466, 824, 18002, 198, 286, 80493, 2270, 466, 824, 4452, 18002, 198, 286, 80493, 3538, 18002, 198, 286, 22612, 242, 16991, 1273, 6031, 4452, 18002, 198, 262, 22612, 242, 16991, 1887, 18002, 198, 144663, 16991, 12439, 198, 262, 80493, 33394, 198, 286, 80493, 4568, 2015, 18002, 198, 286, 22612, 242, 16991, 4568, 2015, 4452, 18002, 198, 262, 80493, 2699, 746, 1314, 198, 286, 22612, 242, 16991, 2699, 746, 1314, 18002, 198, 262, 80493, 2193, 1314, 198, 286, 80493, 1273, 691, 198, 310, 80493, 2193, 9432, 35, 198, 394, 80493, 2331, 33406, 198, 394, 80493, 5670, 77763, 17, 33406, 198, 394, 22612, 242, 16991, 5670, 33406, 198, 310, 80493, 2193, 9432, 36, 198, 394, 80493, 2331, 33406, 198, 394, 80493, 5670, 77763, 17, 33406, 198, 394, 22612, 242, 16991, 5670, 33406, 198, 310, 80493, 5248, 198, 394, 80493, 2193, 32, 198, 503, 80493, 2331, 33406, 198, 503, 80493, 5670, 77763, 17, 33406, 198, 503, 22612, 242, 16991, 5670, 33406, 198, 394, 80493, 2193, 33, 198, 503, 80493, 2331, 33406, 198, 503, 80493, 5670, 77763, 17, 33406, 198, 503, 22612, 242, 16991, 5670, 33406, 198, 394, 80493, 2193, 34, 198, 503, 80493, 2331, 33406, 198, 503, 80493, 5670, 77763, 17, 33406, 198, 503, 22612, 242, 16991, 5670, 33406, 198, 394, 80493, 2193, 37, 198, 503, 80493, 2331, 33406, 198, 503, 22612, 242, 16991, 5670, 77763, 17, 33406, 198, 394, 22612, 242, 16991, 2193, 38, 198, 503, 80493, 2331, 33406, 198, 503, 80493, 5670, 77763, 17, 33406, 198, 503, 22612, 242, 16991, 5670, 33406, 198, 310, 22612, 242, 16991, 3175, 198, 394, 80493, 2331, 33406, 198, 394, 80493, 23594, 33406, 198, 394, 22612, 242, 16991, 1273, 33406, 198, 286, 80493, 2193, 18002, 198, 286, 22612, 242, 16991, 2193, 4452, 18002, 198, 262, 80493, 7681, 454, 198, 286, 80493, 9873, 88536, 18002, 198, 286, 80493, 9873, 88536, 4452, 18002, 198, 286, 80493, 4568, 2015, 18002, 198, 286, 80493, 4568, 2015, 4452, 18002, 198, 286, 80493, 1681, 11529, 18002, 198, 286, 22612, 242, 16991, 1681, 11529, 4452, 18002, 198, 262, 80493, 53758, 1306, 1314, 198, 286, 80493, 53758, 1306, 1314, 18002, 198, 286, 22612, 242, 16991, 53758, 1306, 1314, 4452, 18002, 198, 262, 80493, 26588, 1314, 198, 286, 80493, 26588, 1314, 18002, 198, 286, 80493, 26588, 1314, 4452, 18002, 198, 286, 22612, 242, 16991, 37664, 18002, 198, 262, 80493, 1848, 1314, 198, 286, 80493, 1848, 1314, 18002, 198, 286, 22612, 242, 16991, 1848, 1314, 4452, 18002, 198, 262, 80493, 5181, 1314, 198, 286, 22612, 242, 16991, 5181, 1314, 18002, 198, 262, 80493, 7013, 198, 286, 22612, 242, 16991, 7013, 18002, 198, 262, 80493, 17364, 198, 286, 80493, 10619, 10841, 18002, 198, 286, 22612, 242, 16991, 10619, 10841, 4452, 18002, 198, 262, 80493, 54320, 628, 321, 198, 286, 80493, 1182, 1847, 18002, 198, 286, 80493, 54320, 628, 321, 18002, 198, 286, 80493, 54320, 628, 321, 4452, 18002, 198, 286, 80493, 55026, 18002, 198, 286, 22612, 242, 16991, 55026, 4452, 18002, 198, 262, 80493, 11446, 198, 286, 80493, 2193, 18002, 198, 286, 22612, 242, 16991, 8844, 18002, 198, 262, 80493, 34679, 2186, 198, 286, 80493, 2415, 18002, 198, 286, 22612, 242, 16991, 2415, 4452, 18002, 198, 262, 80493, 1487, 198, 286, 80493, 1487, 18002, 198, 286, 22612, 242, 16991, 5925, 18002, 198, 262, 80493, 1833, 2141, 198, 286, 80493, 1833, 2141, 18002, 198, 286, 22612, 242, 16991, 1833, 2141, 4452, 18002, 198, 262, 80493, 7860, 1314, 198, 286, 80493, 7860, 1314, 18002, 198, 286, 22612, 242, 16991, 7860, 1314, 4452, 18002, 198, 262, 80493, 4179, 1314, 198, 286, 22612, 242, 16991, 4179, 1314, 18002, 198, 262, 80493, 2643, 1314, 198, 286, 22612, 242, 16991, 2643, 1314, 18002, 198, 262, 80493, 10382, 1314, 198, 286, 22612, 242, 16991, 10382, 1314, 18002, 198, 262, 80493, 25991, 1314, 198, 286, 80493, 60146, 7573, 18002, 198, 286, 80493, 60146, 7573, 4452, 18002, 198, 286, 80493, 25991, 1314, 18002, 198, 286, 22612, 242, 16991, 25991, 1314, 4452, 18002, 198, 262, 80493, 914, 746, 198, 286, 22612, 242, 16991, 914, 746, 18002, 198, 262, 80493, 12811, 1314, 198, 286, 80493, 31532, 18002, 198, 286, 22612, 242, 16991, 31532, 4452, 18002, 198, 262, 80493, 1273, 1314, 198, 286, 22612, 242, 16991, 1273, 1314, 18002, 198, 262, 22612, 242, 16991, 882, 1314, 198, 286, 80493, 9021, 18002, 198, 286, 80493, 9021, 4452, 18002, 198, 286, 80493, 882, 1314, 18002, 198, 286, 22612, 242, 16991, 882, 1314, 4452, 18002, 198, 144663, 16991, 7405, 1192, 198, 144663, 16991, 61945, 21324, 198, 144663, 16991, 23789, 14120, 33936, 198, 144663, 16991, 8502, 11527, 82, 3909, 198, 144798, 16991, 8502, 81094, 3909, 198, 522, 4987, 38283, 1339, 4624, 3988, 510, 27, 23319, 1269, 397, 29870, 4698, 81, 3366, 198, 522, 23319, 1269, 1339, 4624, 70934, 7933, 510, 27, 42909, 2638, 397, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 198, 27, 42909, 2638, 1339, 29019, 18320, 389, 279, 12542, 594, 5944, 323, 8794, 510, 27, 35499, 42682, 397, 2, 27612, 16447, 3366, 220, 73345, 111116, 106379, 101042, 271, 565, 220, 16, 13, 220, 106871, 31905, 101978, 271, 56296, 16447, 3366, 54851, 46944, 334, 33447, 78882, 112896, 72448, 14, 102724, 334, 73345, 3837, 102093, 100751, 40549, 18137, 243, 250, 65101, 9370, 393, 17, 47, 58657, 28291, 1773, 105464, 102226, 104587, 52334, 72448, 3837, 102298, 101213, 48934, 47874, 110195, 3837, 100629, 105896, 106379, 70500, 5373, 85767, 39352, 5373, 104814, 5373, 81705, 33108, 102121, 99788, 1773, 73345, 100176, 110053, 3837, 102298, 20, 24, 22, 18947, 88086, 26898, 3837, 101909, 106888, 112896, 72448, 3407, 565, 220, 17, 13, 220, 46100, 44956, 99706, 99778, 14, 101042, 271, 14374, 51461, 116, 63109, 98380, 33108, 99558, 11622, 26355, 198, 12, 3070, 47, 17, 47, 40549, 18137, 243, 250, 65101, 17177, 28291, 334, 5122, 104210, 6495, 67892, 66521, 237, 96422, 9370, 102202, 100811, 65101, 17177, 28291, 72448, 198, 12, 3070, 44636, 30440, 106375, 33071, 334, 5122, 100143, 104102, 220, 16, 20, 11, 15, 15, 15, 26853, 108, 110293, 9370, 103414, 3837, 102443, 17177, 28291, 100381, 220, 16, 15, 15, 220, 110156, 23404, 198, 12, 3070, 105063, 99718, 100143, 334, 5122, 100143, 99960, 103414, 42192, 99644, 105173, 33108, 101312, 105653, 33447, 78882, 198, 12, 3070, 44636, 107769, 33071, 334, 5122, 42192, 23990, 27442, 105716, 9370, 112896, 106379, 70500, 271, 14374, 89982, 30534, 110569, 102064, 198, 12, 3070, 10850, 334, 5122, 99558, 110569, 102064, 3837, 100751, 55338, 100185, 110195, 198, 12, 3070, 30280, 334, 5122, 100751, 102705, 81705, 33108, 102011, 100037, 21894, 198, 12, 3070, 25287, 13710, 334, 5122, 100751, 102121, 33108, 104004, 100037, 21894, 198, 12, 3070, 56, 31102, 334, 5122, 100751, 85767, 26898, 198, 12, 3070, 29475, 14, 5835, 11295, 1220, 334, 5122, 100751, 118417, 102011, 271, 14374, 91417, 60949, 102724, 33108, 44956, 105537, 198, 12, 3070, 71356, 104516, 334, 5122, 104210, 10130, 14, 82354, 58143, 35926, 91282, 393, 17, 47, 66521, 237, 96422, 198, 12, 3070, 105653, 33447, 78882, 334, 5122, 50, 18, 5373, 38, 6412, 5373, 39, 62266, 5373, 35, 13659, 32112, 10236, 255, 231, 101312, 33447, 78882, 100143, 198, 12, 3070, 74393, 334, 5122, 81772, 9909, 100751, 104603, 108418, 32108, 23083, 12, 3070, 109300, 104001, 13343, 334, 5122, 35, 13659, 5373, 4502, 67, 18137, 249, 228, 12857, 198, 12, 3070, 104814, 104118, 334, 5122, 51, 745, 5373, 44, 18, 6567, 234, 229, 30844, 72448, 198, 12, 3070, 8903, 77128, 334, 5122, 57, 391, 89254, 77835, 32108, 8903, 77128, 198, 12, 3070, 81705, 334, 5122, 10850, 4891, 236, 253, 21287, 81705, 488, 13027, 27764, 271, 565, 220, 18, 13, 220, 73345, 106130, 106379, 100261, 99759, 14, 101042, 271, 14374, 18137, 40419, 99371, 106130, 100166, 198, 13874, 3989, 29870, 4698, 81, 3366, 5894, 144663, 16991, 8315, 14, 1843, 671, 20713, 44054, 93437, 9909, 102121, 18493, 103991, 110293, 23083, 144663, 16991, 6238, 14, 688, 671, 17116, 44054, 93437, 9909, 105223, 105975, 92374, 23083, 144663, 16991, 28331, 14, 260, 671, 40179, 44054, 93437, 9909, 47, 17, 47, 66521, 237, 47872, 31548, 23083, 144663, 16991, 13291, 14, 1843, 671, 32778, 44054, 93437, 9909, 35, 13659, 62579, 49026, 20742, 107736, 23083, 144663, 16991, 1936, 21492, 14, 257, 671, 7854, 12, 1552, 44054, 93437, 9909, 105151, 100261, 99759, 47874, 23083, 144663, 16991, 6200, 14, 310, 671, 51461, 116, 63109, 20074, 100166, 33108, 107018, 198, 144663, 16991, 3051, 14, 1797, 671, 34369, 109, 71743, 44956, 33108, 102011, 198, 144663, 16991, 2193, 14, 688, 671, 18137, 44104, 21596, 26898, 108421, 198, 144663, 16991, 26588, 14, 688, 671, 40549, 18137, 243, 250, 65101, 104004, 26898, 198, 144663, 16991, 10295, 14, 286, 671, 18137, 225, 101, 100463, 19793, 26355, 33108, 100013, 99719, 198, 144663, 16991, 33765, 14, 310, 671, 66374, 62042, 90867, 20742, 198, 144663, 16991, 7375, 14, 1843, 671, 42849, 227, 99262, 102011, 33108, 118417, 198, 144663, 16991, 1273, 14, 310, 671, 18137, 249, 228, 12857, 81705, 198, 144663, 16991, 12439, 14, 1843, 671, 220, 105600, 102011, 44956, 198, 144798, 16991, 68909, 14, 1843, 671, 98313, 66635, 105717, 64429, 198, 13874, 19324, 14374, 91417, 60949, 85767, 26898, 198, 12, 3070, 8078, 1192, 334, 5122, 99960, 100133, 104004, 5373, 35, 13659, 18137, 243, 250, 65101, 104004, 5373, 81705, 33108, 102121, 198, 12, 3070, 1676, 23540, 41466, 334, 5122, 99200, 110195, 105549, 85767, 108421, 198, 12, 3070, 51899, 3663, 5122, 42, 29827, 18137, 225, 101, 100463, 85767, 198, 12, 3070, 28648, 3663, 5122, 109300, 32108, 85767, 271, 14374, 220, 106008, 106130, 98380, 101042, 271, 820, 51461, 116, 63109, 110195, 106130, 198, 12, 3070, 8092, 3663, 5122, 103991, 110293, 101913, 101259, 47874, 3837, 101884, 40549, 62579, 49026, 20742, 107736, 198, 12, 3070, 8611, 3663, 5122, 105223, 105975, 92374, 3837, 105653, 33108, 17177, 28291, 23404, 62262, 198, 12, 3070, 50395, 3663, 5122, 47, 17, 47, 10236, 121, 239, 68065, 102020, 31548, 3837, 39352, 32664, 49567, 92374, 64064, 198, 12, 3070, 22803, 3663, 5122, 52526, 101259, 3837, 54542, 100811, 65101, 83751, 105971, 105004, 6238, 198, 12, 3070, 5834, 21492, 3663, 5122, 105151, 26939, 106208, 9370, 100261, 99759, 47874, 3837, 100143, 99960, 103414, 105173, 271, 820, 80090, 107, 68878, 44956, 106130, 198, 12, 3070, 2153, 3663, 5122, 100185, 20074, 100166, 9909, 45217, 5373, 1731, 6370, 5373, 30888, 1731, 5373, 12175, 1731, 23083, 12, 3070, 2740, 3663, 5122, 101203, 98380, 44956, 9909, 33447, 78882, 105653, 5373, 47, 17, 47, 8908, 108, 225, 26381, 5373, 99722, 101071, 49567, 23083, 12, 3070, 6031, 3663, 5122, 105600, 102011, 32804, 9909, 71356, 5373, 8903, 77128, 5373, 85767, 49567, 27866, 820, 18137, 225, 101, 100463, 33108, 102011, 106130, 198, 12, 3070, 1676, 3663, 5122, 99200, 110195, 85767, 108421, 198, 12, 3070, 28648, 3663, 5122, 109300, 32108, 102121, 26898, 198, 12, 3070, 51899, 3663, 5122, 42, 29827, 18137, 225, 101, 100463, 116996, 198, 12, 3070, 51668, 3663, 5122, 100013, 99719, 33108, 102121, 19793, 26355, 198, 12, 3070, 15918, 3663, 5122, 104650, 102011, 9909, 118417, 5373, 102111, 81705, 49567, 27866, 565, 220, 19, 13, 51461, 116, 63109, 98380, 106510, 102450, 271, 14374, 89982, 30534, 98380, 105539, 198, 16, 13, 3070, 47, 17, 47, 18137, 243, 250, 65101, 17177, 28291, 334, 5122, 104210, 101624, 107898, 36556, 46448, 28029, 9370, 102202, 17177, 28291, 71356, 198, 17, 13, 3070, 42140, 33447, 78882, 105653, 100143, 334, 5122, 50, 18, 5373, 38, 6412, 5373, 39, 62266, 5373, 35, 13659, 32112, 10236, 255, 231, 198, 18, 13, 3070, 99960, 103414, 105173, 334, 5122, 100143, 104210, 104190, 9370, 62945, 64682, 105173, 198, 19, 13, 3070, 44636, 107769, 33071, 70500, 334, 5122, 35926, 100939, 98671, 99658, 86312, 5373, 42192, 23990, 27442, 105716, 198, 20, 13, 3070, 106875, 101048, 334, 5122, 45439, 80090, 107, 68878, 5373, 104925, 100359, 5373, 20074, 111293, 48927, 271, 14374, 6567, 248, 112, 99760, 9370, 5333, 58143, 102705, 107736, 198, 12, 3070, 35, 13659, 32112, 5333, 334, 5122, 114288, 100142, 40549, 62579, 49026, 20742, 107736, 198, 12, 3070, 9230, 25414, 5333, 334, 5122, 99200, 110195, 17881, 104516, 107736, 198, 12, 3070, 47, 17, 47, 66521, 237, 96422, 334, 5122, 104210, 6495, 67892, 43589, 35926, 91282, 101136, 198, 12, 3070, 99722, 101071, 107736, 334, 5122, 104814, 33108, 113308, 107736, 271, 565, 220, 20, 13, 220, 46100, 44956, 100403, 101042, 271, 14374, 91417, 60949, 46100, 106393, 33108, 104559, 271, 820, 51461, 116, 63109, 20074, 100166, 106393, 320, 2153, 53560, 12, 3070, 36339, 18002, 334, 5122, 33145, 17, 20, 21, 6567, 239, 246, 30534, 54542, 3837, 100143, 74393, 105653, 33108, 116951, 32108, 198, 12, 3070, 4059, 466, 824, 18002, 334, 5122, 67892, 34369, 225, 27369, 39352, 3837, 102298, 17177, 34718, 27369, 33108, 48927, 20074, 198, 12, 3070, 16537, 62, 19922, 3346, 334, 5122, 32664, 49567, 92374, 27369, 33108, 102285, 16744, 39352, 198, 12, 3070, 2733, 8296, 18002, 334, 5122, 67892, 220, 27369, 98671, 99658, 43959, 33108, 48927, 271, 820, 393, 17, 47, 8908, 108, 225, 26381, 31548, 106393, 320, 2740, 5523, 48709, 2687, 15222, 53560, 12, 3070, 63122, 18002, 334, 5122, 47, 17, 47, 10236, 121, 239, 68065, 109894, 44091, 39352, 198, 12, 3070, 12389, 18002, 334, 5122, 57621, 102474, 106293, 105359, 100674, 198, 12, 3070, 18274, 3663, 5122, 64064, 107415, 33108, 20074, 107468, 39352, 198, 12, 3070, 5148, 3663, 5122, 32664, 49567, 92374, 64064, 33108, 118376, 54542, 271, 820, 38433, 236, 78882, 105653, 106393, 320, 2740, 70020, 53560, 12, 3070, 13297, 18002, 334, 5122, 103967, 33447, 78882, 105653, 39352, 31548, 198, 12, 3070, 82, 18, 20942, 3663, 5122, 36136, 328, 18, 53497, 246, 99871, 33447, 78882, 198, 12, 3070, 70, 4837, 20942, 3663, 5122, 14444, 14817, 14693, 38433, 236, 78882, 198, 12, 3070, 71, 34378, 20942, 3663, 5122, 39, 25268, 58657, 51827, 28330, 26898, 72448, 33447, 78882, 198, 12, 3070, 29172, 20942, 3663, 5122, 35, 13659, 62579, 49026, 20742, 33447, 78882, 271, 820, 6567, 234, 223, 99379, 32108, 29258, 41321, 106393, 320, 2740, 4322, 4975, 291, 44848, 53560, 12, 3070, 13297, 18002, 334, 5122, 62945, 64682, 88802, 108069, 29258, 41321, 100674, 198, 12, 3070, 4578, 9995, 1693, 3663, 5122, 105151, 105173, 88802, 75117, 31548, 198, 12, 3070, 4934, 1419, 3663, 5122, 20074, 18397, 61443, 88802, 75117, 31548, 271, 14374, 93920, 114, 77835, 100144, 33108, 70500, 99453, 28330, 198, 16, 13, 3070, 57621, 102474, 106379, 334, 5122, 107415, 31548, 37029, 57621, 101353, 54542, 44091, 105359, 198, 17, 13, 3070, 101255, 14224, 32108, 70500, 334, 5122, 33447, 78882, 105653, 101910, 104285, 100144, 33108, 107736, 111372, 198, 18, 13, 3070, 48934, 47874, 106379, 334, 5122, 99200, 110195, 102024, 102121, 3837, 67338, 10130, 16341, 17, 47, 220, 104516, 198, 19, 13, 3070, 35926, 100939, 70500, 334, 5122, 98671, 99658, 86312, 33108, 99722, 101071, 101884, 105716, 100756, 102005, 198, 20, 13, 3070, 17177, 99371, 106379, 334, 5122, 104542, 9370, 111372, 100920, 3837, 45181, 100185, 20074, 100166, 26939, 103923, 104913, 271, 565, 220, 21, 13, 54599, 96808, 100920, 14, 100261, 99759, 271, 14374, 220, 105072, 98380, 198, 16, 13, 3070, 47, 17, 47, 18137, 243, 250, 65101, 17177, 28291, 72448, 1019, 17, 13, 3070, 42140, 33447, 78882, 105653, 100143, 1019, 18, 13, 3070, 99960, 103414, 105173, 39352, 1019, 19, 13, 3070, 104814, 33108, 113308, 72448, 56177, 14374, 220, 106587, 98380, 271, 820, 393, 17, 47, 18137, 243, 250, 65101, 17177, 28291, 72448, 198, 12, 3070, 16810, 220, 47874, 334, 5122, 35, 13659, 62579, 49026, 20742, 107736, 101884, 198, 12, 3070, 13298, 220, 47874, 334, 5122, 105975, 92374, 33108, 20074, 105653, 198, 12, 3070, 31133, 220, 47874, 334, 5122, 32664, 49567, 92374, 102020, 198, 12, 3070, 47, 17, 47, 8908, 108, 225, 26381, 31548, 334, 5122, 64064, 108069, 20074, 107468, 271, 820, 40666, 248, 33447, 78882, 105653, 100143, 198, 12, 3070, 50, 18, 38433, 236, 78882, 334, 5122, 36136, 328, 18, 18137, 249, 228, 12857, 198, 12, 3070, 38, 6412, 38433, 236, 78882, 334, 5122, 14444, 14817, 14693, 18137, 249, 228, 12857, 198, 12, 3070, 39, 62266, 38433, 236, 78882, 334, 5122, 39, 25268, 58657, 51827, 28330, 26898, 72448, 102705, 198, 12, 3070, 15603, 38433, 236, 78882, 334, 5122, 35, 13659, 62579, 49026, 20742, 102705, 198, 12, 3070, 105653, 39352, 31548, 334, 5122, 103967, 105653, 107736, 271, 820, 8908, 71933, 103414, 105173, 39352, 198, 12, 3070, 11066, 12, 1552, 220, 47874, 334, 5122, 105151, 100261, 99759, 39352, 198, 12, 3070, 16219, 220, 47874, 334, 5122, 52526, 101259, 33108, 116817, 198, 12, 3070, 105151, 105173, 334, 5122, 62945, 64682, 105173, 88802, 198, 12, 3070, 20074, 18397, 61443, 334, 5122, 99982, 24360, 20074, 108418, 32108, 271, 820, 74866, 239, 99332, 33108, 113308, 72448, 198, 12, 3070, 99722, 101071, 334, 5122, 110195, 44091, 104814, 198, 12, 3070, 104118, 104412, 334, 5122, 102111, 33108, 37029, 100787, 198, 12, 3070, 8903, 77128, 72448, 334, 5122, 100166, 32108, 8903, 77128, 65577, 198, 12, 3070, 118417, 102011, 334, 5122, 71356, 44091, 118417, 271, 14374, 220, 107049, 98380, 271, 820, 393, 17, 47, 8908, 108, 225, 26381, 31548, 110837, 198, 12, 3070, 64064, 44091, 39352, 334, 5122, 32664, 49567, 92374, 64064, 113509, 198, 12, 3070, 20074, 17177, 28291, 104238, 334, 5122, 101474, 18830, 104747, 5373, 47363, 104238, 198, 12, 3070, 71356, 57621, 54542, 334, 5122, 57621, 100394, 33108, 100030, 198, 12, 3070, 118376, 101136, 334, 5122, 32664, 49567, 92374, 101294, 48927, 271, 820, 53497, 246, 99871, 72448, 110837, 198, 12, 3070, 43815, 100246, 40623, 105653, 334, 5122, 104210, 106208, 105918, 105653, 198, 12, 3070, 99982, 24360, 39352, 334, 5122, 104603, 99982, 24360, 104238, 33108, 104886, 198, 12, 3070, 52526, 105653, 334, 5122, 104875, 52526, 26898, 39352, 198, 12, 3070, 23305, 20074, 39352, 334, 5122, 26898, 23305, 27369, 108418, 32108, 271, 565, 220, 22, 13, 19468, 251, 102569, 106098, 101086, 271, 14374, 18137, 250, 222, 30534, 100700, 104136, 107402, 198, 16, 13, 3070, 47, 17, 47, 10236, 121, 239, 68065, 105318, 334, 5122, 8344, 67892, 66521, 237, 96422, 33108, 101624, 107898, 36556, 46448, 28029, 198, 17, 13, 3070, 35, 13659, 62579, 49026, 20742, 101136, 334, 5122, 100811, 65101, 83751, 72225, 100674, 33108, 5333, 54955, 226, 99453, 198, 18, 13, 3070, 112896, 98671, 99658, 86312, 334, 5122, 118661, 98671, 99658, 33108, 118878, 107101, 198, 19, 13, 3070, 43815, 100246, 40623, 105653, 334, 5122, 104210, 106208, 105918, 99877, 75768, 198, 20, 13, 3070, 57621, 102474, 110569, 334, 5122, 62945, 64682, 57621, 54542, 100144, 271, 14374, 4891, 231, 235, 21596, 100032, 101882, 198, 16, 13, 3070, 10850, 220, 102064, 99896, 334, 5122, 117206, 5373, 115668, 5373, 107736, 70500, 198, 17, 13, 3070, 35, 13659, 4891, 253, 118, 99806, 334, 5122, 109300, 5373, 100811, 65101, 5373, 61689, 20742, 101290, 198, 18, 13, 3070, 71356, 110569, 334, 5122, 9230, 5373, 49896, 5373, 47, 17, 47, 10236, 121, 239, 68065, 198, 19, 13, 3070, 112896, 72448, 334, 5122, 31400, 10236, 238, 228, 67831, 5373, 118661, 5373, 107769, 33071, 198, 20, 13, 3070, 46324, 10236, 36097, 54658, 39352, 334, 5122, 26898, 72448, 5373, 101556, 39352, 5373, 71356, 85767, 271, 565, 220, 23, 13, 220, 46100, 100166, 101042, 271, 14374, 93920, 114, 77835, 104040, 198, 12, 3070, 48934, 47874, 106379, 334, 5122, 102024, 105646, 110195, 198, 12, 3070, 57621, 102474, 106379, 334, 5122, 62945, 64682, 57621, 54542, 198, 12, 3070, 17177, 99371, 106379, 334, 5122, 104542, 9370, 111372, 100920, 271, 14374, 51461, 116, 63109, 110195, 100145, 198, 13874, 3989, 16810, 47464, 51018, 40179, 47464, 51018, 17116, 198, 220, 77854, 286, 77854, 286, 77854, 198, 38878, 11397, 393, 17, 47, 8141, 198, 220, 77854, 198, 5793, 55260, 198, 13874, 19324, 14374, 89982, 30534, 21515, 33108, 106393, 198, 12, 3070, 38878, 334, 5122, 47, 17, 47, 10236, 121, 239, 68065, 107415, 100185, 198, 12, 3070, 21839, 334, 5122, 20074, 107468, 39352, 198, 12, 3070, 29699, 58298, 334, 5122, 105653, 33447, 78882, 39352, 198, 12, 3070, 67892, 42502, 334, 5122, 67892, 69594, 39352, 198, 12, 3070, 30888, 1731, 14, 30888, 1972, 334, 5122, 32664, 49567, 92374, 27369, 271, 565, 220, 24, 13, 62262, 88653, 101042, 271, 14374, 62262, 76837, 105946, 271, 820, 18137, 243, 250, 65101, 62189, 102054, 198, 13874, 3989, 35, 13659, 8423, 11397, 20713, 11397, 40179, 320, 45912, 32664, 49567, 92374, 340, 1797, 77854, 198, 16810, 11397, 17116, 14, 10197, 388, 320, 62189, 20074, 17177, 34718, 340, 1797, 77854, 198, 16810, 11397, 8774, 14693, 320, 99982, 24360, 340, 1797, 77854, 198, 35, 13659, 8423, 320, 31526, 100811, 65101, 20074, 340, 13874, 19324, 820, 18137, 243, 250, 65101, 52526, 102054, 198, 13874, 3989, 35, 13659, 8423, 11397, 32778, 11397, 17116, 320, 105653, 23404, 340, 503, 77854, 198, 394, 7854, 12, 1552, 320, 105653, 105151, 100261, 99759, 340, 503, 77854, 198, 394, 55260, 14693, 320, 108418, 32108, 340, 13874, 19324, 14374, 91417, 60949, 20074, 104949, 198, 12, 3070, 45217, 334, 5122, 33145, 17, 20, 21, 6567, 239, 246, 30534, 3837, 43815, 100246, 40623, 105549, 198, 12, 3070, 12175, 1731, 334, 5122, 67892, 34369, 225, 27369, 3837, 102298, 17177, 34718, 33108, 48927, 27369, 198, 12, 3070, 30888, 1731, 334, 5122, 32664, 49567, 92374, 27369, 3837, 102298, 6790, 5373, 78882, 39426, 5373, 44091, 198, 12, 3070, 1731, 6370, 334, 5122, 67892, 220, 27369, 98671, 99658, 3837, 47, 17, 47, 10236, 121, 239, 68065, 106918, 271, 565, 220, 16, 15, 13, 18137, 249, 228, 12857, 33108, 106375, 27442, 102450, 271, 14374, 55059, 240, 14224, 72448, 198, 12, 3070, 33447, 78882, 105653, 101255, 14224, 334, 5122, 67338, 104285, 100144, 61689, 100676, 105653, 33447, 78882, 198, 12, 3070, 109300, 104001, 13343, 101255, 14224, 334, 5122, 100143, 40549, 58143, 9678, 67, 198, 12, 3070, 71356, 57621, 101255, 14224, 334, 5122, 30440, 106375, 9370, 57621, 54542, 100674, 271, 14374, 220, 106961, 72448, 102705, 198, 12, 3070, 42, 29827, 18137, 249, 228, 12857, 334, 5122, 67338, 62042, 90867, 20742, 102121, 198, 12, 3070, 104814, 72448, 102705, 334, 5122, 100143, 386, 18, 5373, 16635, 35, 6567, 234, 229, 30844, 198, 12, 3070, 8903, 77128, 72448, 102705, 334, 5122, 100166, 32108, 8903, 77128, 66017, 198, 12, 3070, 11237, 14, 6484, 18137, 249, 228, 12857, 334, 5122, 35, 13659, 18137, 243, 250, 65101, 104004, 33108, 90447, 271, 565, 220, 16, 16, 13, 220, 105537, 101042, 271, 14374, 40666, 244, 32948, 105537, 198, 12, 3070, 105653, 47874, 334, 5122, 36136, 328, 18, 5373, 14444, 14817, 14693, 5373, 39, 62266, 198, 12, 3070, 74393, 334, 5122, 81772, 9909, 104603, 108418, 32108, 23083, 12, 3070, 109300, 104001, 13343, 334, 5122, 35, 13659, 5373, 4502, 67, 198, 12, 3070, 104814, 72448, 334, 5122, 44, 18, 5373, 16635, 35, 198, 12, 3070, 71356, 44956, 334, 5122, 100142, 5994, 10236, 121, 239, 68065, 44956, 271, 14374, 68739, 227, 32948, 110195, 105537, 198, 12, 3070, 5386, 11397, 5688, 334, 5122, 100185, 20074, 100166, 99250, 55338, 44956, 37029, 198, 12, 3070, 9194, 11397, 17954, 334, 5122, 44956, 105537, 105600, 102011, 32804, 198, 12, 3070, 10443, 11397, 5688, 334, 5122, 99200, 110195, 105537, 101203, 44956, 198, 12, 3070, 18200, 11397, 14563, 82, 334, 5122, 81705, 105537, 105717, 64429, 271, 565, 220, 16, 17, 13, 50042, 99257, 88653, 100261, 99759, 271, 14374, 4891, 116, 116, 88970, 20002, 102122, 271, 820, 81947, 28291, 28946, 102122, 198, 16, 13, 3070, 104603, 100013, 334, 5122, 37029, 3483, 18855, 38433, 107, 27733, 104603, 99719, 198, 17, 13, 3070, 100811, 65101, 104004, 334, 5122, 110885, 100811, 65101, 26939, 16447, 3366, 18137, 249, 228, 99430, 198, 18, 13, 3070, 100811, 65101, 102121, 334, 5122, 45181, 16447, 3366, 6567, 233, 231, 18158, 100811, 65101, 102121, 99892, 271, 820, 32181, 238, 99479, 102122, 198, 16, 13, 3070, 103414, 102121, 334, 5122, 37029, 62042, 73562, 66374, 64118, 102121, 198, 17, 13, 3070, 104814, 113308, 334, 5122, 67338, 104118, 33108, 8903, 77128, 104814, 72448, 44091, 198, 18, 13, 3070, 105716, 105853, 334, 5122, 37029, 118417, 102011, 101042, 71356, 44091, 271, 14374, 91417, 60949, 105333, 27442, 198, 12, 3070, 16810, 334, 5122, 35, 13659, 41479, 95, 17523, 78882, 104396, 105333, 198, 12, 3070, 16219, 334, 5122, 100811, 65101, 110885, 9370, 105333, 198, 12, 3070, 39, 23162, 90867, 20742, 334, 5122, 42, 29827, 18137, 225, 101, 100463, 105333, 198, 12, 3070, 85767, 26898, 334, 5122, 72448, 85767, 9370, 105333, 271, 565, 220, 16, 18, 13, 53040, 100104, 100166, 100367, 271, 14374, 4891, 119, 118, 96422, 9370, 111116, 106379, 271, 820, 220, 16, 13, 220, 73345, 99706, 99778, 198, 12, 220, 73345, 100157, 33108, 100185, 100162, 198, 12, 93920, 114, 77835, 113608, 33108, 117200, 198, 12, 88940, 104, 94299, 55286, 105866, 198, 12, 220, 106961, 104520, 9370, 104877, 271, 820, 220, 17, 13, 96155, 222, 99216, 103642, 33108, 105537, 198, 12, 84238, 244, 38507, 102064, 33108, 102724, 50404, 198, 12, 40666, 244, 32948, 105537, 33108, 102705, 101882, 198, 12, 10236, 36097, 54658, 101882, 33108, 114288, 33071, 271, 820, 220, 18, 13, 51461, 116, 63109, 106379, 70500, 198, 12, 43614, 112, 31914, 106379, 28029, 33108, 110195, 100145, 198, 12, 393, 17, 47, 10236, 121, 239, 68065, 70500, 105318, 198, 12, 62262, 88653, 33108, 100359, 88653, 101042, 198, 12, 18137, 40419, 107769, 105178, 36629, 28726, 70500, 271, 820, 220, 19, 13, 51461, 116, 63109, 110195, 118801, 271, 67331, 220, 19, 13, 16, 20713, 44054, 93437, 198, 12, 20713, 93920, 114, 77835, 33108, 104559, 198, 12, 40549, 62579, 49026, 20742, 107736, 101884, 198, 12, 393, 17, 47, 8908, 108, 225, 26381, 31548, 70500, 198, 12, 220, 104603, 105653, 39352, 198, 12, 18137, 44104, 21596, 32665, 118801, 271, 67331, 220, 19, 13, 17, 17116, 44054, 93437, 198, 12, 17116, 93920, 114, 77835, 33108, 105975, 92374, 100780, 198, 12, 4891, 241, 230, 99658, 86312, 33108, 118878, 107101, 198, 12, 38433, 236, 78882, 105653, 102705, 198, 12, 62262, 52526, 33108, 62189, 54542, 198, 12, 18137, 44104, 21596, 32665, 118801, 271, 67331, 220, 19, 13, 18, 40179, 44054, 93437, 198, 12, 40179, 93920, 114, 77835, 33108, 102020, 98380, 198, 12, 69162, 49567, 92374, 39352, 198, 12, 41479, 96, 51827, 101136, 101884, 198, 12, 34369, 225, 110042, 198, 12, 18137, 44104, 21596, 32665, 118801, 271, 67331, 220, 19, 13, 19, 32778, 44054, 93437, 198, 12, 32778, 93920, 114, 77835, 33108, 101259, 98380, 198, 12, 18137, 243, 250, 65101, 52526, 116817, 198, 12, 7854, 12, 1552, 18137, 249, 228, 12857, 198, 12, 18137, 95, 226, 99259, 33108, 98841, 18158, 100674, 198, 12, 18137, 44104, 21596, 32665, 118801, 271, 67331, 220, 19, 13, 20, 7854, 12, 1552, 44054, 93437, 198, 12, 7854, 12, 1552, 93920, 114, 77835, 33108, 105151, 39352, 198, 12, 51461, 229, 61755, 26939, 106208, 100261, 99759, 198, 12, 8908, 71933, 103414, 105173, 100674, 198, 12, 220, 105537, 106637, 31548, 198, 12, 18137, 44104, 21596, 32665, 118801, 271, 820, 220, 20, 13, 51461, 116, 63109, 44956, 33108, 102011, 271, 67331, 220, 20, 13, 16, 51461, 116, 63109, 20074, 100166, 320, 2153, 53560, 12, 53289, 6567, 239, 246, 30534, 54542, 198, 12, 15819, 1731, 34369, 225, 27369, 39352, 198, 12, 45147, 1731, 69162, 49567, 92374, 27369, 198, 12, 13074, 6370, 220, 27369, 98671, 99658, 271, 67331, 220, 20, 13, 17, 393, 17, 47, 8908, 108, 225, 26381, 31548, 320, 2740, 5523, 48709, 2687, 15222, 53560, 12, 8908, 108, 225, 26381, 31548, 106379, 70500, 198, 12, 220, 57621, 102474, 104949, 198, 12, 32181, 252, 29077, 44091, 39352, 198, 12, 62262, 17177, 28291, 104238, 198, 12, 10236, 121, 239, 68065, 57621, 54542, 271, 67331, 220, 20, 13, 18, 38433, 236, 78882, 105653, 72448, 320, 2740, 70020, 53560, 12, 53497, 246, 99871, 111372, 33108, 107736, 70500, 198, 12, 328, 18, 38433, 236, 78882, 101884, 198, 12, 479, 6412, 38433, 236, 78882, 101884, 198, 12, 472, 62266, 38433, 236, 78882, 101884, 198, 12, 32112, 38433, 236, 78882, 101884, 198, 12, 53497, 246, 99871, 39352, 31548, 271, 67331, 220, 20, 13, 19, 6567, 234, 223, 99379, 32108, 29258, 41321, 72448, 320, 2740, 4322, 4975, 291, 44848, 53560, 12, 52506, 224, 64682, 88802, 39352, 198, 12, 51461, 229, 61755, 105173, 100674, 198, 12, 62262, 18397, 61443, 54542, 198, 12, 93178, 247, 29056, 54542, 33108, 29258, 41321, 104238, 271, 67331, 220, 20, 13, 20, 4891, 223, 98, 99446, 101071, 72448, 320, 2740, 14, 12120, 2028, 53560, 12, 4891, 223, 98, 99446, 101071, 106379, 198, 12, 89982, 27733, 33108, 107285, 101071, 198, 12, 32181, 229, 102980, 31548, 33108, 104814, 198, 12, 10236, 39366, 39352, 271, 67331, 220, 20, 13, 21, 53497, 246, 99871, 72448, 320, 2740, 31320, 53560, 12, 68739, 227, 36629, 100246, 40623, 105653, 198, 12, 84238, 241, 24360, 39352, 104238, 198, 12, 220, 52526, 104875, 105653, 198, 12, 34369, 225, 20074, 108418, 32108, 198, 12, 97259, 227, 21887, 33108, 101999, 271, 820, 220, 21, 13, 18137, 44104, 21596, 39352, 198, 12, 18137, 44104, 21596, 26898, 100166, 33108, 117206, 198, 12, 38433, 226, 110195, 85767, 118801, 198, 12, 10236, 236, 107, 99279, 105149, 85767, 198, 12, 18137, 44104, 21596, 48927, 33108, 47363, 25511, 198, 12, 54599, 101, 35243, 85767, 50007, 271, 820, 220, 22, 13, 18137, 225, 101, 100463, 33108, 113308, 271, 67331, 220, 22, 13, 16, 220, 104603, 100013, 99719, 198, 12, 6040, 18855, 85658, 105866, 198, 12, 40549, 1198, 2900, 18137, 225, 101, 100463, 198, 12, 81947, 28291, 102011, 33108, 110760, 271, 67331, 220, 22, 13, 17, 58263, 51232, 99719, 102121, 198, 12, 66374, 18137, 225, 101, 100463, 105866, 198, 12, 62042, 90867, 20742, 85767, 198, 12, 8908, 113, 226, 37984, 100367, 33108, 47872, 90172, 198, 12, 41479, 231, 35987, 85767, 33108, 102179, 100419, 271, 67331, 220, 22, 13, 18, 74866, 239, 99332, 33108, 113308, 198, 12, 6567, 234, 229, 30844, 104412, 33108, 104814, 198, 12, 75402, 77128, 108069, 101042, 198, 12, 4891, 223, 98, 99446, 101071, 33108, 57555, 99511, 198, 12, 90476, 100, 26232, 47872, 90172, 105866, 198, 12, 43614, 227, 99884, 105853, 108750, 271, 820, 220, 23, 13, 5333, 26853, 224, 77598, 111116, 198, 12, 40549, 32112, 5333, 34369, 120, 36629, 33071, 198, 12, 44054, 93437, 17881, 10130, 5333, 198, 12, 393, 17, 47, 66521, 237, 96422, 101931, 198, 12, 4891, 223, 98, 99446, 101071, 78882, 27442, 198, 12, 10236, 106, 94, 21887, 33108, 85767, 5333, 271, 820, 220, 24, 13, 62262, 104949, 33108, 105653, 198, 12, 51461, 116, 63109, 20074, 100166, 91282, 198, 12, 62262, 44956, 100144, 70500, 198, 12, 53497, 246, 99871, 68805, 101931, 198, 12, 62262, 113274, 33108, 71109, 39352, 271, 820, 220, 16, 15, 13, 41479, 231, 35987, 33071, 70500, 198, 12, 33424, 97, 33477, 33108, 102204, 100674, 198, 12, 41654, 18137, 44104, 21596, 33108, 104759, 39352, 198, 12, 10236, 121, 239, 68065, 99464, 33108, 104925, 100359, 198, 12, 62262, 111293, 100153, 198, 12, 41479, 231, 35987, 102179, 100419, 271, 820, 220, 16, 16, 13, 90476, 100, 26232, 33108, 106375, 33071, 198, 12, 90476, 100, 26232, 109371, 81705, 198, 12, 46750, 102, 76313, 33071, 101042, 198, 12, 41479, 117, 32757, 100367, 105866, 198, 12, 90476, 100, 26232, 47872, 90172, 101898, 198, 12, 8908, 112, 253, 27366, 81705, 39907, 271, 820, 220, 16, 17, 13, 18137, 249, 228, 12857, 33108, 106375, 198, 12, 38433, 236, 78882, 105653, 101255, 14224, 100013, 198, 12, 41479, 117, 31548, 104001, 13343, 102705, 198, 12, 74866, 239, 99332, 72448, 102705, 198, 12, 20694, 14, 6484, 98313, 223, 52510, 43268, 102705, 198, 12, 50331, 113699, 102011, 102705, 271, 820, 220, 16, 18, 13, 98313, 66635, 104238, 198, 12, 66521, 243, 23305, 81705, 105866, 198, 12, 18137, 249, 228, 12857, 81705, 102724, 198, 12, 90476, 100, 26232, 81705, 39907, 198, 12, 10236, 104, 107, 26939, 78882, 81705, 198, 12, 98313, 66635, 20074, 39352, 271, 820, 220, 16, 19, 13, 83002, 98, 76813, 33108, 102556, 74220, 198, 12, 26853, 107, 57452, 32108, 102011, 37029, 198, 12, 90476, 100, 26232, 101042, 102011, 198, 12, 62262, 113274, 102011, 198, 12, 18137, 44104, 21596, 48927, 102011, 198, 12, 43614, 227, 99884, 105262, 102011, 271, 820, 220, 16, 20, 13, 81947, 28291, 105866, 198, 12, 220, 46100, 102007, 105866, 198, 12, 81947, 28291, 99719, 104870, 198, 12, 220, 46100, 101931, 33108, 102179, 100419, 198, 12, 8908, 108, 225, 41321, 33108, 81705, 39907, 198, 12, 69425, 51827, 102054, 271, 820, 220, 16, 21, 13, 43614, 227, 99884, 105853, 33108, 101536, 86119, 198, 12, 4891, 116, 116, 88970, 86119, 106185, 198, 12, 43614, 227, 99884, 105853, 102054, 198, 12, 93178, 247, 29056, 46100, 101275, 198, 12, 90476, 100, 26232, 86119, 105262, 198, 12, 10236, 121, 239, 68065, 86119, 105853, 271, 820, 220, 16, 22, 13, 64388, 230, 21894, 100022, 33108, 113274, 198, 12, 64388, 230, 21894, 90447, 66394, 198, 12, 38433, 239, 33447, 114288, 33071, 198, 12, 32181, 223, 59534, 105866, 198, 12, 52506, 225, 11622, 98380, 198, 12, 66521, 229, 52334, 101898, 271, 565, 220, 16, 19, 13, 6567, 118, 238, 46100, 9909, 105537, 26898, 7552, 102802, 271, 14374, 220, 73345, 99706, 99778, 198, 12, 508, 54675, 21324, 9533, 54675, 21324, 340, 12, 508, 8078, 1192, 9533, 8078, 1192, 692, 14374, 51461, 116, 63109, 110195, 271, 820, 20713, 44054, 93437, 198, 12, 508, 8092, 15351, 18002, 9533, 8092, 15351, 18002, 340, 12, 508, 8092, 83033, 83033, 18002, 9533, 8092, 83033, 83033, 18002, 340, 12, 508, 8092, 14, 8092, 4030, 37255, 18002, 9533, 8092, 14, 8092, 4030, 37255, 18002, 340, 12, 508, 8092, 14, 8092, 2972, 25085, 18002, 9533, 8092, 14, 8092, 2972, 25085, 18002, 692, 820, 17116, 44054, 93437, 2303, 12, 508, 8611, 15351, 18002, 9533, 8611, 15351, 18002, 340, 12, 508, 8611, 83033, 83033, 18002, 9533, 8611, 83033, 83033, 18002, 340, 12, 508, 8611, 34827, 4030, 37255, 18002, 9533, 8611, 34827, 4030, 37255, 18002, 340, 12, 508, 8611, 34827, 2972, 25085, 18002, 9533, 8611, 34827, 2972, 25085, 18002, 692, 820, 40179, 44054, 93437, 198, 12, 508, 50395, 15351, 18002, 9533, 50395, 15351, 18002, 340, 12, 508, 50395, 83033, 83033, 18002, 9533, 50395, 83033, 83033, 18002, 340, 12, 508, 50395, 14, 13131, 388, 2836, 37255, 18002, 9533, 50395, 14, 13131, 388, 2836, 37255, 18002, 340, 12, 508, 50395, 14, 65512, 2972, 25085, 18002, 9533, 50395, 14, 65512, 2972, 25085, 18002, 692, 820, 32778, 44054, 93437, 198, 12, 508, 22803, 15351, 18002, 9533, 22803, 15351, 18002, 340, 12, 508, 22803, 83033, 83033, 18002, 9533, 22803, 83033, 83033, 18002, 340, 12, 508, 22803, 18008, 4130, 4030, 37255, 18002, 9533, 22803, 18008, 4130, 4030, 37255, 18002, 692, 820, 7854, 12, 1552, 44054, 93437, 198, 12, 508, 5834, 21492, 15351, 18002, 9533, 5834, 21492, 15351, 18002, 340, 12, 508, 5834, 21492, 83033, 83033, 18002, 9533, 5834, 21492, 83033, 83033, 18002, 340, 12, 508, 5834, 21492, 76196, 4030, 37255, 18002, 9533, 5834, 21492, 76196, 4030, 37255, 18002, 340, 12, 508, 5834, 21492, 76196, 4314, 31320, 18002, 9533, 5834, 21492, 76196, 4314, 31320, 18002, 692, 14374, 51461, 116, 63109, 20074, 100166, 198, 12, 508, 2153, 3446, 15153, 18002, 9533, 2153, 3446, 15153, 18002, 340, 12, 508, 2153, 90228, 466, 824, 18002, 9533, 2153, 90228, 466, 824, 18002, 340, 12, 508, 2153, 14, 16537, 3109, 18002, 9533, 2153, 14, 16537, 3109, 18002, 340, 12, 508, 2153, 54976, 8296, 18002, 9533, 2153, 54976, 8296, 18002, 340, 12, 508, 2153, 14, 16537, 8467, 18002, 9533, 2153, 14, 16537, 8467, 18002, 692, 14374, 393, 17, 47, 8908, 108, 225, 26381, 31548, 198, 12, 508, 2740, 5523, 48709, 2687, 15222, 2687, 15222, 18002, 9533, 2740, 5523, 48709, 2687, 15222, 2687, 15222, 18002, 340, 12, 508, 2740, 5523, 48709, 2687, 15222, 42764, 18002, 9533, 2740, 5523, 48709, 2687, 15222, 42764, 18002, 340, 12, 508, 2740, 5523, 48709, 2687, 15222, 63796, 18002, 9533, 2740, 5523, 48709, 2687, 15222, 63796, 18002, 340, 12, 508, 2740, 5523, 48709, 2687, 15222, 41510, 3400, 41510, 3400, 261, 18002, 9533, 2740, 5523, 48709, 2687, 15222, 41510, 3400, 41510, 3400, 261, 18002, 340, 12, 508, 2740, 5523, 48709, 2687, 15222, 14, 5148, 14, 5148, 18002, 9533, 2740, 5523, 48709, 2687, 15222, 14, 5148, 14, 5148, 18002, 692, 14374, 38433, 236, 78882, 105653, 72448, 198, 12, 508, 2740, 70020, 14, 13297, 18002, 9533, 2740, 70020, 14, 13297, 18002, 340, 12, 508, 2740, 70020, 14730, 18002, 9533, 2740, 70020, 14730, 18002, 340, 12, 508, 2740, 70020, 25085, 18002, 9533, 2740, 70020, 25085, 18002, 340, 12, 508, 2740, 70020, 2687, 18, 20942, 25085, 18002, 9533, 2740, 70020, 2687, 18, 20942, 25085, 18002, 340, 12, 508, 2740, 70020, 4846, 4837, 20942, 25085, 18002, 9533, 2740, 70020, 4846, 4837, 20942, 25085, 18002, 340, 12, 508, 2740, 70020, 7530, 34378, 20942, 25085, 18002, 9533, 2740, 70020, 7530, 34378, 20942, 25085, 18002, 340, 12, 508, 2740, 70020, 14, 29172, 20942, 34827, 2972, 18002, 9533, 2740, 70020, 14, 29172, 20942, 34827, 2972, 18002, 692, 14374, 53497, 246, 99871, 72448, 198, 12, 508, 2740, 31320, 80591, 14809, 18002, 9533, 2740, 31320, 80591, 14809, 18002, 340, 12, 508, 2740, 31320, 62094, 14809, 18002, 9533, 2740, 31320, 62094, 14809, 18002, 340, 12, 508, 2740, 31320, 37173, 14809, 18002, 9533, 2740, 31320, 37173, 14809, 18002, 340, 12, 508, 2740, 31320, 26090, 23903, 14809, 18002, 9533, 2740, 31320, 26090, 23903, 14809, 18002, 340, 12, 508, 2740, 31320, 3183, 7603, 3183, 7603, 18002, 9533, 2740, 31320, 3183, 7603, 3183, 7603, 18002, 692, 14374, 6567, 234, 223, 99379, 32108, 29258, 41321, 72448, 198, 12, 508, 2740, 4322, 4975, 291, 44848, 14, 13297, 18002, 9533, 2740, 4322, 4975, 291, 44848, 14, 13297, 18002, 340, 12, 508, 2740, 4322, 4975, 291, 44848, 76196, 9995, 1693, 14, 80787, 18002, 9533, 2740, 4322, 4975, 291, 44848, 76196, 9995, 1693, 14, 80787, 18002, 340, 12, 508, 2740, 4322, 4975, 291, 44848, 64264, 1419, 14, 80787, 18002, 9533, 2740, 4322, 4975, 291, 44848, 64264, 1419, 14, 80787, 18002, 692, 14374, 4891, 223, 98, 99446, 101071, 72448, 198, 12, 508, 2740, 14, 12120, 2028, 14, 32225, 18002, 9533, 2740, 14, 12120, 2028, 14, 32225, 18002, 340, 12, 508, 2740, 14, 12120, 2028, 46619, 261, 18002, 9533, 2740, 14, 12120, 2028, 46619, 261, 18002, 340, 12, 508, 2740, 14, 12120, 2028, 63524, 18002, 9533, 2740, 14, 12120, 2028, 63524, 18002, 692, 14374, 10236, 121, 239, 68065, 57621, 54542, 198, 12, 508, 2740, 5523, 48709, 38065, 49710, 440, 684, 14, 58912, 18002, 9533, 2740, 5523, 48709, 38065, 49710, 440, 684, 14, 58912, 18002, 340, 12, 508, 2740, 5523, 48709, 38065, 49710, 440, 684, 42764, 18002, 9533, 2740, 5523, 48709, 38065, 49710, 440, 684, 42764, 18002, 692, 14374, 18137, 44104, 21596, 39352, 198, 12, 508, 1676, 14, 8092, 26090, 33406, 9533, 1676, 14, 8092, 26090, 33406, 340, 12, 508, 1676, 14, 8611, 26090, 33406, 9533, 1676, 14, 8611, 26090, 33406, 340, 12, 508, 1676, 21485, 9683, 26090, 33406, 9533, 1676, 21485, 9683, 26090, 33406, 340, 12, 508, 1676, 18008, 4130, 26090, 33406, 9533, 1676, 18008, 4130, 26090, 33406, 340, 12, 508, 1676, 30593, 21492, 26090, 33406, 9533, 1676, 30593, 21492, 26090, 33406, 692, 14374, 18137, 225, 101, 100463, 33108, 113308, 198, 12, 508, 28648, 14, 8092, 14953, 13659, 1192, 9533, 28648, 14, 8092, 14953, 13659, 1192, 340, 12, 508, 28648, 14, 8611, 14953, 13659, 1192, 9533, 28648, 14, 8611, 14953, 13659, 1192, 340, 12, 508, 28648, 21485, 9683, 14953, 13659, 1192, 9533, 28648, 21485, 9683, 14953, 13659, 1192, 340, 12, 508, 51899, 14, 14488, 33406, 9533, 51899, 14, 14488, 33406, 340, 12, 508, 51899, 96985, 33406, 9533, 51899, 96985, 33406, 340, 12, 508, 51668, 35061, 18855, 14, 54675, 21324, 9533, 51668, 35061, 18855, 14, 54675, 21324, 340, 12, 508, 51668, 14109, 23, 82, 14, 54675, 21324, 9533, 51668, 14109, 23, 82, 14, 54675, 21324, 692, 14374, 83002, 98, 76813, 33108, 102556, 74220, 198, 12, 508, 15918, 8749, 14, 88981, 15351, 18002, 9533, 15918, 8749, 14, 88981, 15351, 18002, 340, 12, 508, 15918, 8749, 4322, 617, 261, 15351, 18002, 9533, 15918, 8749, 4322, 617, 261, 15351, 18002, 340, 12, 508, 15918, 8749, 10758, 1078, 15351, 18002, 9533, 15918, 8749, 10758, 1078, 15351, 18002, 692, 14374, 98313, 66635, 198, 12, 508, 1944, 23266, 31236, 723, 477, 7197, 9533, 1944, 23266, 31236, 723, 477, 7197, 340, 12, 508, 1944, 23266, 12697, 15467, 7197, 9533, 1944, 23266, 12697, 15467, 7197, 340, 12, 508, 1944, 23266, 12697, 814, 13659, 7197, 9533, 1944, 23266, 12697, 814, 13659, 7197, 692, 14374, 220, 105600, 102011, 44956, 198, 12, 508, 6031, 14730, 1314, 14730, 18002, 9533, 6031, 14730, 1314, 14730, 18002, 340, 12, 508, 6031, 14, 96336, 628, 321, 14, 96336, 628, 321, 18002, 9533, 6031, 14, 96336, 628, 321, 14, 96336, 628, 321, 18002, 340, 12, 508, 6031, 19413, 19413, 18002, 9533, 6031, 19413, 19413, 18002, 340, 12, 508, 6031, 38065, 1314, 38065, 1314, 18002, 9533, 6031, 38065, 1314, 38065, 1314, 18002, 692, 99487, 111116, 106379, 114369, 27612, 16447, 3366, 220, 73345, 105679, 100185, 98380, 5373, 99361, 104449, 33108, 37029, 102122, 3837, 17714, 99604, 100920, 9370, 20002, 104257, 105896, 112872, 1773, 111116, 100166, 99929, 100662, 34187, 104913, 104542, 33071, 3837, 99518, 103944, 34187, 99361, 102217, 3837, 100006, 101929, 45181, 84607, 102569, 26939, 104112, 100013, 106017, 99604, 100354, 8997, 522, 35499, 42682, 1339, 27, 14172, 13429, 287, 397, 2610, 614, 7375, 518, 697, 33445, 311, 11625, 279, 10822, 3383, 13, 11112, 1493, 5601, 8826, 5392, 6738, 510, 16, 13, 67414, 1795, 279, 5392, 1618, 10802, 6896, 438, 5189, 323, 1281, 2704, 311, 3410, 678, 5871, 5029, 624, 17, 13, 576, 10435, 1231, 5785, 7375, 429, 525, 902, 5021, 2500, 13, 55025, 1618, 7375, 429, 525, 537, 20975, 3897, 624, 18, 13, 3155, 537, 990, 1140, 4334, 18822, 311, 6851, 6220, 476, 1034, 1995, 429, 374, 2669, 3897, 304, 279, 2390, 5944, 624, 522, 14172, 13429, 287, 1339, 7771, 5795, 374, 311, 9245, 264, 2390, 18906, 9705, 12626, 14257, 504, 15817, 6358, 315, 279, 2038, 3152, 11, 61945, 11, 323, 12613, 7236, 13, 576, 5944, 1265, 8683, 438, 279, 16266, 369, 264, 9705, 3910, 11, 53829, 311, 2176, 46850, 323, 10321, 13402, 19178, 382, 36850, 7354, 510, 16, 13, 28596, 10816, 510, 256, 7854, 264, 7299, 11591, 9705, 28922, 429, 25963, 279, 2390, 594, 3692, 7321, 198, 17, 13, 75938, 59170, 510, 256, 7405, 2704, 429, 23759, 17189, 448, 678, 10007, 9705, 8502, 3685, 198, 18, 13, 9258, 23470, 510, 256, 15042, 1590, 3059, 304, 279, 2567, 4718, 3561, 271, 8420, 525, 279, 8502, 369, 279, 9705, 5944, 9471, 510, 16, 13, 12260, 1172, 2924, 14158, 429, 7866, 311, 5042, 2038, 3152, 6813, 11, 3516, 11, 323, 11537, 4419, 304, 279, 2390, 624, 17, 13, 28596, 7321, 1265, 1795, 2390, 594, 19819, 6396, 323, 10306, 42463, 6193, 287, 624, 18, 13, 45945, 2449, 1969, 2432, 2390, 594, 2038, 3152, 34918, 323, 990, 12966, 34948, 44493, 624, 19, 13, 29734, 2009, 5333, 9705, 311, 3421, 678, 584, 24099, 323, 2924, 14887, 28703, 624, 20, 13, 758, 279, 2213, 11, 6832, 32724, 1265, 1191, 448, 6770, 18940, 11, 1221, 5098, 311, 10847, 13347, 624, 21, 13, 30846, 1550, 11591, 916, 5072, 448, 11682, 5785, 9705, 624, 22, 13, 29734, 14158, 369, 3709, 35812, 8474, 11, 13713, 11221, 323, 6770, 10431, 10295, 624, 23, 13, 39565, 12235, 14158, 369, 1817, 4565, 11, 1186, 91840, 11, 4625, 10802, 11, 5333, 323, 2473, 624, 24, 13, 8883, 10191, 1969, 2924, 678, 4583, 4419, 323, 1186, 91840, 7289, 624, 16, 15, 13, 21159, 2213, 1741, 438, 68671, 27193, 323, 6203, 11591, 10431, 1265, 387, 5230, 304, 8311, 1992, 624, 16, 16, 13, 9177, 6546, 11, 48041, 11, 323, 8894, 3501, 624, 16, 17, 13, 17207, 3684, 1265, 387, 304, 19819, 7321, 11, 4135, 42156, 12624, 624, 16, 18, 13, 1752, 1817, 3772, 11, 10542, 323, 2924, 279, 1429, 9760, 2530, 3542, 504, 279, 2390, 438, 17749, 2458, 10695, 624, 16, 19, 13, 1416, 279, 2390, 374, 1602, 4285, 11, 279, 2197, 5944, 1265, 387, 438, 4285, 438, 3204, 624, 16, 20, 13, 1416, 902, 2697, 3542, 3000, 320, 68, 1302, 2572, 678, 525, 5335, 11, 7868, 3542, 11, 4992, 24389, 498, 1265, 537, 6923, 894, 9293, 624, 16, 21, 13, 576, 1482, 2038, 3152, 702, 220, 2697, 3542, 11, 421, 279, 1372, 315, 2697, 3542, 374, 2686, 1091, 220, 20, 15, 11, 279, 2197, 5944, 646, 1172, 614, 825, 2188, 11, 323, 1172, 264, 9814, 16800, 311, 279, 3542, 374, 3897, 198, 16, 22, 13, 4320, 944, 3331, 1182, 13, 20678, 432, 697, 678, 382, 5097, 15042, 510, 785, 1590, 2550, 1265, 387, 264, 4718, 5944, 14064, 279, 9705, 28922, 13, 5443, 279, 2701, 3561, 1447, 334, 8973, 46817, 30990, 52225, 27972, 28763, 25, 1019, 16, 13, 1446, 27732, 1191, 697, 2033, 448, 366, 76303, 38283, 29, 320, 67685, 4772, 340, 17, 13, 1446, 27732, 835, 697, 2033, 448, 690, 76303, 38283, 29, 320, 85777, 4772, 340, 18, 13, 3155, 4183, 990, 894, 1008, 3561, 476, 9492, 198, 19, 13, 576, 4583, 6358, 1265, 387, 19472, 1948, 366, 76303, 38283, 29, 323, 690, 76303, 38283, 29, 9492, 198, 20, 13, 576, 9934, 27732, 387, 438, 11682, 438, 3204, 11, 37838, 1128, 2213, 3880, 311, 387, 5230, 323, 11682, 18821, 389, 279, 5944, 315, 279, 2213, 271, 334, 5370, 28596, 24580, 25, 1019, 12, 2265, 25, 3070, 8164, 334, 3772, 12, 15909, 198, 12, 829, 25, 3070, 8164, 334, 11113, 829, 198, 12, 17749, 2458, 25, 3070, 8164, 334, 1759, 315, 17749, 3542, 369, 12453, 1995, 11, 1084, 374, 264, 8674, 1034, 1815, 448, 5091, 311, 279, 12542, 3704, 6220, 198, 12, 9934, 25, 3070, 8164, 1019, 262, 481, 1752, 279, 12126, 3772, 476, 3704, 6193, 25, 4230, 15817, 2213, 369, 419, 3772, 10735, 389, 508, 56481, 33635, 39892, 70303, 14, 71913, 936, 81917, 1181, 7428, 11, 17646, 11, 323, 5025, 311, 1008, 6813, 13, 11789, 279, 8129, 3565, 11, 6546, 2606, 11, 323, 10431, 12624, 13, 29734, 2176, 43801, 916, 5072, 369, 46850, 323, 10916, 3565, 369, 10321, 13402, 13, 5443, 56626, 12966, 448, 279, 2038, 3152, 13, 39565, 14976, 10295, 44196, 4185, 990, 5048, 13, 11789, 584, 24099, 11, 5029, 11, 323, 470, 2750, 13, 29734, 46187, 1380, 8311, 311, 40368, 1376, 18940, 10346, 262, 481, 1752, 4565, 3565, 25, 7843, 11682, 2213, 369, 419, 1186, 91840, 476, 1186, 41387, 3772, 13, 663, 14438, 398, 10339, 8129, 3565, 11, 28696, 5025, 11, 24099, 11, 7947, 1614, 323, 10431, 12624, 13, 29734, 14175, 10295, 504, 279, 5042, 2038, 3152, 13, 11789, 6546, 2606, 11, 5029, 11, 323, 470, 2750, 13, 81917, 11871, 448, 1008, 6813, 13, 9177, 4185, 4714, 323, 862, 9904, 13, 7405, 2213, 15614, 311, 46850, 1393, 8241, 14016, 10916, 7990, 369, 10321, 13402, 624, 262, 481, 1752, 5333, 2197, 25, 4230, 5333, 9705, 369, 508, 56481, 33635, 5333, 11176, 21531, 15792, 11838, 936, 1752, 25414, 1262, 33356, 11, 2197, 10130, 5413, 11, 5548, 12624, 11, 1681, 98804, 61800, 11, 323, 16653, 5413, 13, 1752, 47042, 33356, 11, 2197, 3633, 11589, 11, 1943, 19856, 11, 1538, 4494, 11, 323, 1931, 7246, 16230, 12624, 13, 1752, 20954, 33356, 11, 2197, 3633, 31785, 11, 821, 14087, 11, 7868, 19856, 11, 323, 1584, 6240, 13, 1752, 45833, 16341, 3444, 10535, 11, 2197, 821, 16842, 11, 1943, 12299, 11, 323, 1882, 57912, 13, 29734, 11507, 18906, 10295, 11, 1465, 11589, 14830, 11, 4763, 37764, 11, 4379, 32894, 11, 323, 2319, 287, 1995, 13, 11789, 4185, 990, 5048, 11, 2943, 8129, 17501, 11, 323, 5068, 25262, 10414, 13, 9177, 11507, 18906, 27703, 7375, 323, 16558, 19827, 13, 1416, 8415, 11, 3410, 11906, 27193, 369, 31590, 4419, 323, 28412, 24748, 8388, 624, 262, 481, 1752, 42463, 9705, 25, 4230, 42463, 9705, 369, 508, 56481, 33635, 70303, 936, 60785, 279, 1550, 11591, 2884, 11, 42463, 12624, 11, 323, 1849, 22711, 13, 11789, 3692, 21880, 11, 821, 27455, 11, 323, 17590, 12624, 13, 81917, 10916, 11181, 11, 6559, 63939, 11, 323, 16982, 13, 29734, 13737, 8502, 11, 93740, 37764, 11, 323, 23172, 44882, 13, 39565, 1849, 2266, 46187, 323, 3692, 29985, 82, 13, 9177, 5312, 42221, 1280, 10520, 1075, 4763, 11, 16558, 11, 323, 20763, 13351, 13, 11789, 5440, 5611, 11, 4843, 24031, 19543, 11, 323, 2319, 24748, 624, 262, 481, 1752, 821, 1614, 25, 4230, 15817, 821, 1614, 9705, 369, 508, 56481, 33635, 47970, 14, 3540, 35839, 936, 25771, 5387, 11871, 11, 2070, 17473, 11, 323, 821, 4494, 13, 11789, 6028, 14, 28443, 6894, 11, 24953, 11, 323, 16982, 13, 81917, 821, 10519, 5601, 323, 2562, 5601, 13, 29734, 4625, 10802, 46187, 323, 6077, 821, 13, 11789, 821, 2615, 12624, 11, 47430, 14830, 11, 323, 5068, 37764, 13, 47395, 821, 47508, 11, 37131, 10186, 11, 323, 93947, 5601, 13, 29734, 821, 11906, 12716, 323, 2319, 6240, 13, 9177, 821, 4763, 11, 12345, 8502, 11, 323, 2615, 2524, 624, 262, 481, 1752, 3689, 6813, 25, 4230, 11682, 9705, 369, 508, 56481, 33635, 3689, 70303, 936, 60785, 279, 3692, 594, 9124, 11094, 11, 7709, 11, 323, 1196, 16230, 12624, 13, 11789, 6914, 14, 12340, 11, 4357, 11, 15711, 11, 323, 48041, 2606, 13, 29734, 10431, 10295, 448, 2038, 68642, 323, 3887, 67253, 13, 39565, 17501, 369, 25988, 2884, 323, 39700, 8733, 13, 11789, 3692, 5302, 11, 26053, 11, 323, 33592, 13, 29734, 1707, 48041, 2606, 323, 1105, 287, 1824, 13, 9177, 5312, 31555, 24748, 323, 5068, 25262, 13, 11789, 3692, 18037, 12624, 323, 17590, 448, 1008, 3689, 5424, 624, 12, 2841, 26564, 25, 3070, 8164, 334, 21149, 911, 1246, 311, 1186, 59394, 279, 1790, 11591, 2197, 5944, 311, 2841, 476, 902, 4623, 44373, 374, 4362, 13, 8278, 6718, 1119, 1186, 21599, 82, 979, 279, 2213, 374, 6351, 26, 421, 279, 2213, 374, 1602, 4285, 11, 5648, 26541, 44373, 624, 262, 481, 3070, 13314, 25, 1019, 286, 481, 576, 4419, 22903, 525, 2238, 6351, 369, 825, 2197, 311, 3421, 678, 279, 6540, 11, 323, 1184, 311, 387, 6718, 3772, 16, 11, 3772, 17, 1119, 279, 2841, 3772, 198, 286, 481, 576, 4688, 5610, 5248, 4419, 1741, 438, 4565, 16, 11, 4565, 17, 11, 4992, 2572, 323, 3880, 311, 387, 7481, 304, 7716, 304, 279, 22848, 82, 624, 286, 481, 576, 4565, 374, 2238, 4285, 13, 576, 1482, 2197, 374, 14016, 311, 3421, 279, 16800, 11, 902, 1184, 311, 6718, 624, 286, 481, 576, 2197, 374, 458, 23251, 943, 323, 1265, 387, 63594, 11, 902, 4623, 44373, 374, 4362, 624, 286, 481, 576, 2681, 2197, 374, 537, 264, 2477, 12, 48482, 943, 1741, 438, 458, 23251, 476, 1273, 11, 323, 5610, 5248, 6351, 1186, 77430, 26, 279, 1482, 2197, 1172, 5707, 264, 4586, 23251, 315, 1493, 1186, 77430, 323, 7460, 4623, 59822, 323, 11682, 16148, 624, 286, 481, 472, 33880, 374, 2238, 5538, 311, 387, 4623, 66509, 4490, 624, 286, 481, 2308, 4623, 44373, 374, 4362, 624, 286, 481, 24369, 9293, 323, 862, 1186, 11527, 2831, 11, 1741, 438, 23251, 11, 7497, 11, 23172, 11, 87577, 11, 323, 6200, 18940, 11, 1265, 387, 7481, 2878, 220, 16, 311, 220, 17, 5866, 11, 2041, 26541, 44373, 624, 286, 481, 24515, 3855, 11, 5440, 5611, 11, 28720, 3880, 311, 387, 438, 4285, 438, 3204, 11, 11523, 304, 264, 3175, 2197, 11, 2041, 44373, 1119, 22848, 82, 198, 12, 702, 31206, 25, 3070, 8164, 334, 4636, 2841, 3119, 11, 498, 27732, 1795, 279, 2841, 3119, 8502, 311, 6718, 279, 2197, 1186, 7837, 1119, 2841, 382, 334, 5370, 28596, 33784, 25, 1019, 12, 3070, 74909, 45234, 20395, 95518, 5443, 24034, 2841, 369, 19819, 1874, 819, 198, 12, 3070, 1092, 52899, 55669, 95518, 29734, 1449, 5089, 12893, 2041, 83118, 198, 12, 3070, 1703, 38446, 95518, 5948, 1817, 3772, 311, 9760, 2530, 3542, 369, 13403, 198, 12, 3070, 57468, 4874, 17582, 12754, 95518, 8886, 9934, 1969, 387, 11682, 11, 1917, 35085, 11, 323, 31554, 3151, 198, 12, 3070, 37889, 4874, 28596, 95518, 86377, 220, 16, 12, 19, 5866, 315, 7990, 7192, 369, 1429, 9705, 271, 334, 9620, 59501, 52517, 25, 1019, 4854, 3772, 9934, 1969, 2924, 510, 12, 3070, 47514, 5578, 25806, 95518, 12023, 21892, 315, 1128, 12893, 374, 1660, 26372, 198, 12, 3070, 36850, 12309, 95518, 26668, 7990, 8311, 369, 279, 3692, 23094, 198, 12, 3070, 3533, 36019, 25311, 95518, 36532, 2038, 10295, 323, 10431, 12624, 504, 279, 5042, 2038, 3152, 198, 12, 3070, 52464, 9608, 95518, 2585, 419, 3692, 35616, 311, 3800, 304, 279, 1849, 198, 12, 3070, 90535, 65, 50240, 81561, 95518, 7718, 4714, 323, 862, 9904, 198, 12, 3070, 34791, 21144, 804, 95518, 57739, 10414, 323, 1850, 12378, 1380, 9760, 271, 334, 9620, 15044, 9680, 52517, 25, 1019, 12, 21144, 3425, 323, 1246, 311, 6718, 279, 1790, 2188, 315, 3772, 198, 12, 1416, 279, 1482, 3772, 5610, 5248, 1186, 98516, 4419, 11, 279, 1790, 2188, 3772, 1265, 387, 6718, 705, 198, 12, 421, 432, 374, 2669, 279, 24632, 3692, 476, 4565, 11, 1513, 944, 6718, 432, 624, 12, 70615, 369, 264, 8172, 1948, 2197, 7990, 323, 10306, 2897, 198, 12, 34006, 26541, 66710, 429, 3643, 10646, 5000, 198, 12, 5787, 23251, 11, 16800, 11, 5440, 5611, 11, 3709, 3855, 11, 7497, 11, 23172, 11, 28720, 323, 862, 1186, 21599, 82, 1265, 537, 1373, 2841, 14158, 438, 1753, 438, 3204, 198, 12, 24515, 3855, 11, 5440, 5611, 3880, 311, 387, 438, 4285, 438, 3204, 11, 11523, 304, 264, 3175, 2197, 11, 2041, 44373, 1119, 22848, 82, 271, 334, 9620, 15044, 38297, 25, 1019, 12, 10548, 311, 279, 2841, 3119, 2213, 11, 421, 1052, 374, 902, 1184, 311, 17576, 1495, 279, 30106, 2188, 11, 738, 702, 31206, 311, 364, 14004, 6967, 5937, 11, 738, 702, 31206, 311, 364, 8996, 3876, 31115, 697, 2033, 1667, 419, 4734, 4718, 5944, 11, 448, 902, 5107, 1467, 1573, 476, 1283, 510, 27, 76303, 38283, 397, 515, 220, 330, 3615, 788, 2278, 262, 341, 414, 330, 2102, 788, 330, 2809, 12, 15909, 756, 414, 330, 606, 788, 330, 9620, 3988, 756, 414, 330, 37555, 2458, 788, 4383, 20432, 2343, 32429, 10758, 8367, 23903, 16, 9220, 497, 330, 20432, 2343, 32429, 10758, 8367, 23903, 17, 9220, 8097, 414, 330, 40581, 788, 330, 4021, 15817, 2213, 2503, 476, 7843, 11682, 2213, 369, 2503, 756, 414, 330, 5864, 26564, 788, 330, 30438, 414, 330, 4648, 31206, 788, 330, 14004, 476, 5664, 698, 262, 456, 220, 5133, 532, 522, 76303, 38283, 1339, 60424, 697, 1590, 12542, 5944, 2213, 2878, 366, 76303, 38283, 29, 9492, 13, 29734, 902, 40841, 476, 6042, 4889, 315, 1493, 9492, 382, 334, 98743, 25, 1446, 27732, 5889, 304, 8453, 7, 104811, 8, 659, 334, 151645, 198, 151644, 872, 271, 8387, 2197, 5944, 4916, 510, 27, 3765, 26231, 38283, 397, 23888, 198, 144663, 16991, 220, 73345, 99706, 99778, 198, 144663, 16991, 88940, 104, 94299, 55286, 198, 144663, 16991, 93920, 114, 77835, 70500, 198, 144663, 16991, 51461, 116, 63109, 110195, 198, 144663, 16991, 34369, 109, 71743, 44956, 72448, 198, 144663, 16991, 18137, 44104, 21596, 39352, 198, 144663, 16991, 18137, 225, 101, 100463, 113308, 198, 144663, 16991, 5333, 26853, 224, 77598, 320, 5405, 2197, 340, 144663, 16991, 62262, 104949, 198, 144663, 16991, 41479, 231, 35987, 70500, 198, 144663, 16991, 90476, 100, 26232, 57218, 106375, 33071, 198, 144663, 16991, 18137, 249, 228, 12857, 57218, 106375, 198, 144663, 16991, 98313, 66635, 104238, 198, 144663, 16991, 83002, 98, 76813, 57218, 102556, 74220, 198, 144663, 16991, 43614, 227, 99884, 105853, 57218, 101536, 86119, 198, 144798, 16991, 81947, 28291, 105866, 198, 522, 3765, 26231, 38283, 1339, 27, 4310, 478, 397, 12, 24369, 9293, 323, 862, 1186, 11527, 2831, 11, 1741, 438, 23251, 11, 7497, 11, 23172, 11, 87577, 11, 323, 6200, 18940, 11, 1265, 387, 7481, 2878, 220, 16, 311, 220, 17, 5866, 11, 2041, 26541, 44373, 624, 12, 24515, 3855, 11, 5440, 5611, 11, 28720, 3880, 311, 387, 438, 4285, 438, 3204, 11, 11523, 304, 264, 3175, 2197, 11, 2041, 44373, 1119, 22848, 82, 198, 522, 4310, 478, 1339, 5097, 1172, 279, 1790, 2188, 2197, 5944, 315, 279, 2701, 2197, 510, 27, 3231, 26231, 397, 13608, 2102, 1210, 364, 2068, 72623, 516, 364, 606, 1210, 364, 7082, 26853, 224, 77598, 516, 364, 37555, 2458, 1210, 2509, 8092, 14, 8092, 4030, 37255, 18002, 516, 364, 8611, 34827, 4030, 37255, 18002, 516, 364, 50395, 14, 13131, 388, 2836, 37255, 18002, 516, 364, 22803, 18008, 4130, 4030, 37255, 18002, 516, 364, 5834, 21492, 76196, 4030, 37255, 18002, 516, 364, 15110, 4322, 17, 79, 4322, 17, 79, 57322, 4089, 364, 40581, 1210, 364, 50377, 16447, 3366, 10236, 36097, 54658, 9370, 100873, 5333, 26853, 224, 77598, 111116, 1773, 100700, 65577, 40549, 32112, 5333, 34369, 120, 36629, 33071, 101884, 3837, 100630, 100811, 65101, 83751, 72225, 5373, 105713, 39352, 5373, 35112, 6567, 36548, 9370, 10130, 10236, 104, 107, 27442, 1773, 100157, 99200, 110195, 106588, 101979, 10130, 5333, 3837, 100630, 20713, 5373, 13298, 5373, 31133, 5373, 16219, 5373, 11066, 12, 1552, 220, 104186, 104516, 107736, 1773, 66833, 41932, 393, 17, 47, 66521, 237, 96422, 101931, 3837, 104210, 18433, 4322, 17, 79, 4322, 17, 79, 57322, 41479, 248, 64559, 105778, 68805, 5373, 64064, 101136, 5373, 118376, 102054, 33108, 20074, 107468, 100674, 1773, 65577, 99722, 101071, 78882, 27442, 3837, 100630, 105266, 33071, 101071, 5373, 80158, 100829, 33071, 101071, 33108, 104118, 107047, 107736, 1773, 99553, 108069, 85767, 5333, 3837, 100630, 104299, 85767, 50007, 5373, 44091, 51154, 5373, 113308, 40090, 107736, 1773, 32664, 103991, 5333, 10236, 104, 107, 27442, 99553, 100700, 66394, 5122, 9230, 81454, 5373, 3144, 6567, 44401, 28330, 5373, 34859, 14, 102808, 100144, 5373, 104510, 39907, 5373, 32100, 46100, 33108, 19793, 26355, 1773, 100630, 5333, 64388, 230, 21894, 39352, 5373, 69041, 33447, 114288, 33071, 5373, 47149, 88653, 104238, 33108, 99464, 101118, 1773, 99553, 104621, 101884, 105866, 5373, 31534, 85658, 19793, 26355, 33108, 102179, 100419, 1773, 516, 364, 5864, 26564, 1210, 364, 7082, 101275, 102298, 101312, 99604, 109963, 7082, 107736, 3837, 73157, 86402, 101103, 101970, 101136, 33108, 105795, 3837, 85106, 101348, 102239, 5122, 35, 13659, 32112, 5333, 5373, 110195, 17881, 9230, 5333, 5373, 47, 17, 47, 101136, 107736, 5373, 99722, 101071, 57218, 39352, 7082, 49567, 44729, 113066, 71817, 100700, 66394, 1773, 516, 364, 4648, 31206, 1210, 364, 14004, 16418, 522, 3231, 26231, 1339, 334, 98743, 25, 1019, 12, 3070, 37889, 4874, 28596, 95518, 9125, 5944, 2188, 220, 17, 11, 10306, 220, 16, 12, 19, 5866, 315, 7990, 7192, 369, 1429, 9705, 198, 12, 1416, 279, 2841, 5944, 2188, 374, 2238, 5538, 11, 902, 4623, 44373, 374, 4362, 624, 12, 9258, 5043, 1969, 6644, 2265, 11, 829, 11, 17749, 2458, 11, 9934, 11, 2841, 26564, 11, 323, 702, 31206, 624, 12, 1446, 27732, 1191, 697, 2033, 448, 366, 76303, 38283, 29, 320, 67685, 4772, 8, 323, 835, 448, 690, 76303, 38283, 29, 320, 85777, 4772, 340, 12, 3155, 537, 990, 1140, 4334, 18822, 311, 6851, 6220, 476, 1034, 1995, 429, 374, 2669, 3897, 304, 279, 2390, 5944, 624, 12, 3155, 537, 2550, 2197, 429, 525, 279, 1852, 438, 279, 1482, 2197, 624, 12, 1446, 60875, 5889, 304, 8453, 7, 104811, 8, 421, 3204, 13, 151645, 198, 151644, 77091, 198, 151657, 198, 27, 1688, 96598, 4136, 3152, 397, 27, 16181, 79194, 397, 35, 13659, 32112, 5333, 10130, 36342, 24083, 11291, 198, 522, 16181, 397, 27, 16181, 86582, 18981, 397, 29172, 11, 2068, 42717, 7706, 198, 522, 16181, 397, 522, 1688, 397, 151658, 151645, 198, 151644, 872, 198, 151665, 198, 9485, 525, 279, 2038, 2711, 3059, 13, 7036, 429, 1493, 1231, 476, 1231, 537, 2924, 279, 2480, 4226, 13, 1084, 374, 705, 311, 498, 311, 10279, 421, 498, 1184, 311, 1618, 803, 7375, 311, 9567, 803, 1995, 382, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 18008, 4130, 14, 29172, 9199, 37255, 18002, 69233, 20, 15, 8125, 24, 17, 198, 322, 32149, 96059, 374, 264, 22188, 1618, 892, 8473, 274, 624, 2830, 320, 82, 353, 5475, 8, 32149, 96059, 368, 1465, 341, 6725, 70419, 445, 24617, 19424, 2812, 3538, 389, 1018, 82, 497, 274, 5423, 64091, 340, 853, 11446, 83535, 1141, 5423, 64091, 11, 274, 31010, 2398, 630, 1313, 16403, 2582, 2036, 341, 197, 44814, 3056, 917, 1565, 2236, 2974, 81907, 8805, 630, 322, 16403, 3050, 13469, 16403, 1681, 624, 322, 3703, 1110, 14120, 91131, 905, 14, 29172, 45389, 10508, 26559, 25584, 369, 803, 5785, 624, 2830, 320, 82, 353, 5475, 8, 16403, 3050, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 1465, 341, 8810, 2353, 48, 1669, 330, 77, 698, 40668, 48, 1669, 330, 4259, 1837, 197, 322, 7854, 1681, 369, 1759, 2354, 44265, 624, 2405, 4051, 4772, 2972, 5814, 5632, 198, 10676, 1669, 435, 20893, 198, 18534, 1669, 575, 15685, 741, 2023, 595, 11, 348, 1669, 2088, 2804, 341, 197, 743, 2422, 3747, 8, 961, 220, 16, 341, 298, 853, 7013, 13080, 1006, 571, 197, 1, 11808, 3239, 1018, 82, 7533, 82, 497, 595, 11, 348, 568, 2522, 19886, 69497, 340, 197, 197, 532, 197, 8961, 595, 341, 197, 2722, 3930, 48, 510, 298, 8810, 2353, 2507, 11, 1848, 1669, 33317, 67107, 3747, 58, 15, 2546, 298, 743, 1848, 961, 2092, 341, 571, 853, 7013, 13080, 1006, 464, 197, 1, 11808, 3930, 1018, 82, 25, 1018, 82, 497, 348, 11, 1848, 568, 2522, 19886, 69497, 340, 298, 197, 532, 298, 743, 3930, 2507, 621, 220, 15, 341, 571, 853, 7013, 13080, 1006, 464, 197, 1, 11808, 3930, 1018, 67, 497, 3930, 2507, 568, 2522, 19886, 69497, 340, 298, 197, 532, 298, 50108, 1214, 2353, 284, 3930, 2507, 198, 197, 2722, 4347, 48, 510, 298, 50108, 61958, 284, 348, 58, 15, 921, 197, 11940, 510, 298, 853, 7013, 13080, 445, 11808, 3239, 1018, 82, 497, 595, 568, 2522, 19886, 69497, 340, 197, 197, 532, 197, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 18008, 4130, 14, 29172, 9199, 37255, 18002, 69233, 15, 8125, 19, 23, 198, 322, 2955, 320, 66, 8, 220, 17, 15, 16, 21, 12, 17, 15, 16, 24, 27612, 24561, 11, 4848, 624, 2289, 322, 10103, 1212, 279, 8914, 1876, 11, 6079, 220, 17, 13, 15, 320, 1782, 330, 9827, 797, 322, 498, 1231, 537, 990, 419, 1034, 3650, 304, 8733, 448, 279, 1876, 624, 322, 1446, 1231, 6851, 264, 2975, 315, 279, 1876, 518, 198, 2289, 322, 28080, 1110, 2136, 5096, 2659, 6971, 10845, 12, 17, 13, 15, 198, 2289, 322, 10878, 2567, 553, 8415, 2329, 476, 7230, 311, 304, 4378, 11, 3162, 198, 322, 4237, 1212, 279, 1876, 374, 4237, 389, 458, 330, 1911, 3424, 1, 11389, 345, 322, 6007, 7427, 2726, 11342, 3008, 4137, 9297, 11, 2987, 3158, 476, 6131, 624, 322, 3496, 279, 1876, 369, 279, 3151, 4128, 10012, 8541, 323, 198, 322, 9481, 1212, 279, 1876, 624, 1722, 19424, 9199, 271, 474, 2399, 197, 42884, 8931, 698, 197, 21871, 698, 197, 42200, 698, 197, 32468, 15627, 698, 197, 32468, 57254, 698, 197, 59658, 698, 197, 39744, 1837, 197, 9749, 905, 25525, 11582, 72, 14, 14604, 698, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 30593, 21492, 76196, 2972, 698, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 21902, 7530, 5252, 698, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 21902, 20936, 798, 698, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 21902, 19413, 698, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 21902, 54544, 746, 698, 692, 322, 8422, 35455, 40549, 19424, 36342, 624, 1313, 8422, 2036, 341, 25873, 262, 5532, 198, 60439, 2959, 4772, 2972, 11716, 198, 630, 322, 1532, 5475, 11450, 264, 501, 8422, 624, 2830, 1532, 5475, 8754, 5532, 11, 4772, 2959, 4772, 2972, 11716, 8, 353, 5475, 341, 853, 609, 5475, 90, 1676, 11, 4772, 2959, 532, 630, 322, 19954, 4675, 264, 7013, 369, 274, 624, 2830, 320, 82, 353, 5475, 8, 19954, 368, 1758, 31010, 341, 7000, 1669, 25798, 7121, 9523, 741, 7000, 2234, 4283, 85, 17, 19632, 26539, 497, 7013, 38968, 1141, 62559, 3050, 1171, 853, 435, 198, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 45714, 8749, 4322, 617, 261, 15351, 18002, 69233, 19, 24, 8125, 21, 21, 198, 197, 322, 16571, 3050, 33045, 2168, 3055, 432, 21189, 11540, 504, 26588, 19424, 624, 197, 322, 1084, 374, 264, 1293, 4303, 1882, 624, 197, 18553, 11, 1848, 1669, 1532, 11196, 3050, 7, 17, 15, 15, 11, 26588, 340, 743, 1848, 961, 2092, 341, 197, 6725, 26133, 3964, 340, 197, 630, 67009, 1669, 59807, 7121, 9523, 741, 67009, 63623, 35460, 6267, 3050, 4292, 197, 197, 17856, 445, 3806, 1138, 67009, 63623, 4283, 38188, 497, 11540, 31010, 4292, 197, 197, 17856, 445, 2946, 1138, 67009, 63623, 4283, 38188, 497, 6267, 3050, 4292, 197, 197, 17856, 445, 3806, 5130, 6725, 70419, 445, 74819, 389, 1018, 82, 497, 8844, 13986, 340, 6725, 26133, 19886, 68334, 96059, 7307, 268, 13986, 11, 9273, 1171, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 18008, 4130, 18008, 4130, 4030, 37255, 18002, 69233, 19, 22, 8125, 23, 16, 198, 853, 609, 5475, 515, 197, 79659, 345, 197, 197, 3564, 4703, 19963, 3050, 12805, 11, 65949, 1326, 197, 197, 3564, 29978, 2995, 3050, 12805, 11, 4772, 2959, 11, 609, 3675, 5668, 6570, 22655, 10472, 11, 65949, 1326, 197, 25873, 345, 197, 532, 630, 322, 19954, 4675, 279, 10130, 7013, 624, 2830, 320, 82, 353, 5475, 8, 19954, 368, 1758, 31010, 341, 7000, 1669, 25798, 7121, 9523, 2822, 7000, 9046, 1255, 11603, 10538, 14099, 1141, 29856, 1171, 7000, 9046, 1255, 11603, 1214, 266, 2251, 10105, 1141, 29856, 4390, 7000, 2234, 4283, 12120, 497, 7013, 38968, 1141, 38107, 3050, 4390, 7000, 23442, 4283, 29172, 14, 38188, 497, 7013, 38968, 1141, 6225, 19963, 3050, 31421, 1171, 7000, 23442, 4283, 22803, 5457, 16, 14, 29172, 4322, 1097, 2995, 497, 274, 63559, 2995, 3050, 31421, 692, 197, 322, 328, 13280, 608, 8349, 87146, 299, 69, 36342, 624, 7000, 1321, 629, 35460, 1758, 13275, 60421, 44, 2200, 692, 853, 435, 198, 630, 2830, 320, 82, 353, 5475, 8, 2820, 3050, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 1465, 341, 11009, 991, 33655, 3622, 11, 330, 3925, 1138, 853, 2092, 198, 630, 2830, 320, 82, 353, 5475, 8, 32149, 96059, 368, 1465, 341, 6725, 70419, 445, 24617, 13291, 3538, 389, 1018, 82, 497, 274, 5423, 64091, 340, 853, 11446, 83535, 1141, 5423, 64091, 11, 274, 31010, 2398, 3733, 151666, 198, 151645, 198, 151644, 77091, 198, 151657, 198, 27, 1688, 96598, 4136, 3152, 397, 27, 16181, 86582, 18981, 397, 19794, 11, 55627, 42717, 7706, 198, 522, 16181, 397, 27, 16181, 79194, 397, 9230, 5333, 11291, 24083, 36342, 3538, 198, 522, 16181, 397, 27, 16181, 96598, 23199, 397, 8092, 14, 8092, 4030, 198, 522, 16181, 397, 522, 1688, 397, 151658, 151645, 198, 151644, 872, 198, 151665, 198, 9485, 525, 279, 2038, 2711, 3059, 13, 7036, 429, 1493, 1231, 476, 1231, 537, 2924, 279, 2480, 4226, 13, 1084, 374, 705, 311, 498, 311, 10279, 421, 498, 1184, 311, 1618, 803, 7375, 311, 9567, 803, 1995, 382, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 21902, 7530, 5252, 7530, 5252, 18002, 69233, 20, 15, 8125, 16, 15, 15, 198, 322, 12104, 11367, 264, 2526, 4247, 311, 384, 624, 2830, 320, 68, 353, 1454, 8, 12104, 5969, 11, 348, 914, 8, 353, 1454, 341, 7727, 17093, 1904, 5969, 11, 348, 340, 853, 384, 198, 630, 322, 2126, 2522, 4675, 279, 1465, 2639, 624, 2830, 320, 68, 353, 1454, 8, 2126, 2522, 368, 526, 341, 853, 384, 4299, 198, 630, 2830, 320, 68, 353, 1454, 8, 4600, 368, 914, 341, 743, 384, 15137, 621, 1591, 341, 197, 853, 8879, 17305, 445, 4030, 1465, 1018, 67, 497, 384, 4299, 340, 197, 532, 853, 8879, 17305, 445, 4030, 1465, 1018, 67, 25, 1018, 82, 497, 384, 4299, 11, 384, 15137, 340, 630, 322, 15495, 3050, 18653, 458, 10130, 7013, 892, 4675, 458, 1465, 624, 1313, 15495, 3050, 2915, 19886, 37508, 11, 353, 1254, 9659, 8, 1465, 271, 322, 42187, 32722, 458, 15495, 3050, 1119, 458, 1758, 89164, 553, 11589, 279, 1465, 198, 322, 5927, 553, 305, 624, 2830, 42187, 3203, 15495, 3050, 8, 1758, 89164, 341, 853, 2915, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 341, 197, 2405, 2639, 526, 198, 197, 2405, 60078, 914, 198, 197, 743, 1848, 1669, 305, 3622, 11, 435, 1215, 1848, 961, 2092, 341, 298, 8961, 384, 1669, 1848, 12832, 1313, 8, 341, 298, 2722, 353, 1454, 510, 571, 2023, 595, 11, 6165, 1669, 2088, 384, 17093, 341, 464, 2023, 8358, 348, 1669, 2088, 6165, 341, 1144, 6692, 15753, 1005, 2212, 5969, 11, 348, 340, 464, 197, 532, 571, 197, 532, 571, 23847, 284, 384, 4299, 198, 571, 9859, 6611, 284, 384, 15137, 198, 298, 11940, 510, 571, 23847, 284, 1758, 66760, 198, 571, 9859, 6611, 284, 384, 6141, 741, 298, 197, 532, 298, 6692, 69794, 13838, 340, 298, 6692, 4073, 10556, 3782, 3964, 6611, 1171, 197, 197, 92, 770, 341, 298, 23847, 284, 1758, 52989, 198, 197, 197, 532, 197, 743, 2639, 2604, 220, 19, 15, 15, 1009, 2639, 961, 220, 19, 15, 19, 341, 298, 6725, 70419, 4430, 67, 1018, 82, 1018, 82, 1018, 82, 497, 2639, 11, 435, 20798, 11, 435, 20893, 17474, 11, 60078, 340, 197, 197, 532, 197, 532, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 8194, 3183, 11603, 3183, 11603, 18002, 69233, 21, 15, 8125, 24, 24, 198, 322, 9926, 2251, 10105, 10953, 14887, 6844, 5946, 624, 2830, 9926, 2251, 10105, 50714, 52295, 77940, 8, 2915, 16913, 1758, 31010, 8, 1758, 31010, 341, 853, 2915, 16913, 1758, 31010, 8, 1758, 31010, 341, 197, 853, 1758, 89164, 18552, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 341, 298, 21375, 1669, 882, 13244, 741, 298, 28144, 83535, 9230, 3622, 11, 435, 340, 298, 60439, 27380, 50714, 11, 435, 568, 10105, 445, 5524, 2251, 1827, 6471, 9730, 93404, 10639, 1171, 197, 197, 3518, 197, 532, 630, 1313, 3255, 2522, 6492, 2036, 341, 28080, 37508, 198, 6692, 5529, 4047, 1807, 198, 43343, 286, 526, 198, 630, 2830, 320, 86, 353, 8548, 2522, 6492, 8, 9645, 4047, 15842, 526, 8, 341, 743, 753, 86, 1418, 5529, 4047, 341, 197, 6692, 10210, 284, 2038, 198, 197, 6692, 1418, 5529, 4047, 284, 830, 198, 197, 6692, 37508, 69794, 15842, 340, 197, 532, 630, 2830, 320, 86, 353, 8548, 2522, 6492, 8, 9645, 1883, 3056, 3782, 8, 320, 396, 11, 1465, 8, 341, 6692, 69794, 19886, 52989, 340, 853, 289, 37508, 4073, 1883, 340, 630, 322, 8104, 14099, 10953, 14887, 2639, 1760, 624, 2830, 8104, 14099, 50714, 52295, 77940, 8, 2915, 16913, 1758, 31010, 8, 1758, 31010, 341, 853, 2915, 16913, 1758, 31010, 8, 1758, 31010, 341, 197, 853, 1758, 89164, 18552, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 341, 298, 71952, 86, 1669, 609, 8548, 2522, 6492, 90, 86, 11, 895, 11, 1758, 52989, 532, 298, 28144, 83535, 9230, 23794, 86, 11, 435, 340, 298, 60439, 27380, 50714, 11, 435, 568, 14099, 4199, 12027, 64109, 23794, 86, 10210, 4579, 39245, 7, 16, 340, 197, 197, 3518, 197, 532, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 18008, 4130, 18008, 4130, 4030, 37255, 18002, 69233, 19, 22, 8125, 23, 16, 198, 853, 609, 5475, 515, 197, 79659, 345, 197, 197, 3564, 4703, 19963, 3050, 12805, 11, 65949, 1326, 197, 197, 3564, 29978, 2995, 3050, 12805, 11, 4772, 2959, 11, 609, 3675, 5668, 6570, 22655, 10472, 11, 65949, 1326, 197, 25873, 345, 197, 532, 630, 322, 19954, 4675, 279, 10130, 7013, 624, 2830, 320, 82, 353, 5475, 8, 19954, 368, 1758, 31010, 341, 7000, 1669, 25798, 7121, 9523, 2822, 7000, 9046, 1255, 11603, 10538, 14099, 1141, 29856, 1171, 7000, 9046, 1255, 11603, 1214, 266, 2251, 10105, 1141, 29856, 4390, 7000, 2234, 4283, 12120, 497, 7013, 38968, 1141, 38107, 3050, 4390, 7000, 23442, 4283, 29172, 14, 38188, 497, 7013, 38968, 1141, 6225, 19963, 3050, 31421, 1171, 7000, 23442, 4283, 22803, 5457, 16, 14, 29172, 4322, 1097, 2995, 497, 274, 63559, 2995, 3050, 31421, 692, 197, 322, 328, 13280, 608, 8349, 87146, 299, 69, 36342, 624, 7000, 1321, 629, 35460, 1758, 13275, 60421, 44, 2200, 692, 853, 435, 198, 630, 2830, 320, 82, 353, 5475, 8, 2820, 3050, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 1465, 341, 11009, 991, 33655, 3622, 11, 330, 3925, 1138, 853, 2092, 198, 630, 2830, 320, 82, 353, 5475, 8, 32149, 96059, 368, 1465, 341, 6725, 70419, 445, 24617, 13291, 3538, 389, 1018, 82, 497, 274, 5423, 64091, 340, 853, 11446, 83535, 1141, 5423, 64091, 11, 274, 31010, 2398, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 30593, 21492, 76196, 4030, 37255, 18002, 69233, 23, 16, 8125, 16, 17, 21, 198, 25873, 284, 2193, 13045, 16273, 2822, 79659, 284, 10472, 23676, 3556, 9147, 14032, 30953, 515, 197, 197, 1, 4352, 788, 330, 4578, 4030, 756, 197, 8824, 853, 609, 5475, 515, 197, 25873, 25, 394, 2193, 345, 197, 79659, 25, 338, 10472, 345, 197, 77446, 1412, 25, 1060, 1182, 1412, 345, 197, 8854, 13298, 61088, 25, 286, 2205, 13298, 61088, 345, 197, 8854, 13298, 2959, 25, 257, 2205, 13298, 2959, 345, 197, 197, 79488, 25, 1797, 18709, 345, 197, 57279, 25, 338, 3553, 345, 197, 197, 1826, 6295, 25, 2290, 1299, 6295, 345, 197, 60439, 18327, 1693, 2043, 25, 4772, 18327, 1693, 2043, 345, 197, 197, 19979, 25, 1060, 9109, 345, 197, 197, 14891, 18190, 25, 1843, 2170, 18190, 345, 197, 532, 630, 322, 19954, 4675, 458, 1758, 31010, 369, 274, 624, 2830, 320, 82, 353, 5475, 8, 19954, 368, 1758, 31010, 341, 7000, 1669, 25798, 7121, 9523, 2822, 7000, 9046, 1255, 11603, 10538, 14099, 1141, 29856, 1171, 7000, 9046, 1255, 11603, 1214, 266, 2251, 10105, 1141, 29856, 4390, 7000, 2234, 4283, 12120, 497, 7013, 38968, 1141, 38107, 3050, 1171, 7000, 2234, 4283, 878, 1880, 497, 7013, 38968, 1141, 4125, 1880, 3973, 3050, 4390, 7000, 39825, 4283, 14082, 9388, 4578, 4472, 36339, 9388, 36339, 9545, 7013, 38968, 1141, 3597, 5668, 3050, 1171, 7000, 90478, 4283, 14082, 9388, 4578, 9545, 7013, 38968, 1141, 6858, 5668, 3050, 1171, 7000, 2234, 4283, 14082, 9388, 4578, 9545, 7013, 38968, 1141, 59279, 3050, 4390, 7000, 2234, 4283, 81907, 9388, 23476, 4472, 14082, 497, 7013, 38968, 1141, 6420, 4624, 3050, 4390, 7000, 2234, 4283, 1607, 1057, 497, 7013, 38968, 1141, 6420, 3050, 4390, 7000, 23442, 4283, 1826, 6295, 84460, 9388, 4578, 9545, 7013, 38968, 1141, 68225, 48795, 5668, 3050, 4390, 7000, 2234, 4283, 8611, 497, 7013, 38968, 1141, 670, 13298, 3050, 4390, 7000, 23442, 1006, 197, 197, 3115, 10481, 3446, 14070, 81771, 6295, 84460, 9388, 4578, 4472, 36339, 9388, 36339, 24375, 197, 53326, 38968, 1141, 950, 14070, 18327, 48795, 5668, 3050, 4390, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 8194, 3183, 11603, 3183, 11603, 4452, 18002, 69233, 16, 15, 18, 8125, 16, 20, 21, 198, 2830, 3393, 2522, 14099, 1155, 353, 8840, 836, 8, 341, 78216, 1669, 3056, 1235, 341, 197, 41653, 1843, 914, 198, 197, 53326, 286, 2915, 19886, 37508, 11, 353, 1254, 9659, 340, 197, 42400, 2522, 914, 198, 197, 59403, 197, 197, 515, 298, 197, 1, 3194, 7013, 14579, 220, 17, 15, 15, 756, 298, 29244, 19886, 37508, 11, 353, 1254, 9659, 8, 14573, 298, 197, 1, 17, 15, 15, 756, 197, 197, 2137, 341, 298, 197, 1, 59079, 1760, 220, 17, 15, 15, 756, 298, 29244, 3622, 1758, 37508, 11, 716, 353, 1254, 9659, 8, 314, 6399, 44747, 3622, 11, 330, 3925, 899, 1153, 298, 197, 1, 17, 15, 15, 756, 197, 197, 2137, 341, 298, 197, 1, 4934, 4247, 756, 298, 29244, 3622, 1758, 37508, 11, 716, 353, 1254, 9659, 8, 314, 289, 69794, 7, 20, 15, 15, 8, 1153, 298, 197, 1, 20, 15, 15, 756, 197, 197, 2137, 341, 298, 197, 1, 35673, 3270, 4247, 6738, 1172, 10953, 1156, 1618, 756, 298, 29244, 3622, 1758, 37508, 11, 716, 353, 1254, 9659, 8, 314, 289, 69794, 7, 19, 15, 15, 1215, 289, 69794, 7, 20, 15, 15, 8, 1153, 298, 197, 1, 19, 15, 15, 756, 197, 197, 1583, 197, 532, 2023, 8358, 1273, 1669, 2088, 7032, 341, 197, 3244, 16708, 8623, 30514, 11, 2915, 1155, 353, 8840, 836, 8, 341, 298, 17957, 1669, 1373, 7121, 1155, 692, 298, 79659, 1669, 52295, 7121, 2271, 10803, 19814, 2092, 692, 298, 7000, 1669, 25798, 7121, 9523, 741, 298, 7000, 9046, 38866, 14099, 50714, 1171, 298, 7000, 2234, 4283, 7975, 9388, 7975, 9545, 1273, 31171, 692, 298, 53183, 11, 2936, 1669, 1273, 1314, 12101, 5475, 2601, 340, 298, 16867, 2936, 2822, 298, 2023, 600, 1669, 220, 15, 26, 600, 366, 220, 20, 26, 600, 1027, 341, 571, 197, 6878, 1848, 1669, 1758, 2234, 28197, 17305, 445, 1254, 86791, 82, 60555, 10776, 497, 10789, 1171, 571, 17957, 35699, 3964, 340, 298, 197, 630, 298, 17957, 12808, 7, 16, 11, 2422, 50714, 808, 9601, 1005, 2507, 388, 12145, 298, 2023, 8358, 348, 1669, 2088, 10472, 808, 9601, 1005, 2507, 388, 368, 341, 571, 17957, 12808, 8623, 56835, 2522, 11, 348, 2967, 2398, 571, 17957, 12808, 1548, 21, 19, 7, 20, 701, 348, 6167, 2398, 571, 17957, 12808, 9147, 14032, 30953, 515, 464, 197, 1, 32540, 788, 330, 7975, 756, 464, 197, 1, 4393, 788, 256, 330, 3806, 756, 571, 197, 2137, 348, 73522, 2398, 298, 197, 532, 197, 197, 3518, 197, 532, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 21902, 7530, 5252, 7530, 5252, 18002, 69233, 15, 8125, 20, 19, 198, 322, 2955, 320, 66, 8, 220, 17, 15, 16, 21, 12, 17, 15, 16, 24, 27612, 24561, 11, 4848, 624, 2289, 322, 10103, 1212, 279, 8914, 1876, 11, 6079, 220, 17, 13, 15, 320, 1782, 330, 9827, 797, 322, 498, 1231, 537, 990, 419, 1034, 3650, 304, 8733, 448, 279, 1876, 624, 322, 1446, 1231, 6851, 264, 2975, 315, 279, 1876, 518, 198, 2289, 322, 28080, 1110, 2136, 5096, 2659, 6971, 10845, 12, 17, 13, 15, 198, 2289, 322, 10878, 2567, 553, 8415, 2329, 476, 7230, 311, 304, 4378, 11, 3162, 198, 322, 4237, 1212, 279, 1876, 374, 4237, 389, 458, 330, 1911, 3424, 1, 11389, 345, 322, 6007, 7427, 2726, 11342, 3008, 4137, 9297, 11, 2987, 3158, 476, 6131, 624, 322, 3496, 279, 1876, 369, 279, 3151, 4128, 10012, 8541, 323, 198, 322, 9481, 1212, 279, 1876, 624, 1722, 7013, 271, 474, 2399, 197, 21871, 698, 197, 32468, 15627, 1837, 197, 9749, 905, 14, 29870, 14, 9855, 3366, 21902, 19413, 698, 692, 322, 4600, 18653, 458, 10130, 7013, 1465, 892, 42569, 23156, 2639, 323, 7102, 198, 322, 311, 387, 738, 304, 279, 10130, 2033, 624, 1313, 4600, 2036, 341, 23847, 526, 198, 20883, 1758, 15753, 198, 21169, 262, 914, 198, 630, 322, 4600, 69, 11450, 264, 501, 4600, 448, 90908, 11297, 36566, 13, 35990, 311, 220, 20, 15, 15, 1465, 624, 2830, 4600, 69, 20698, 914, 11, 2827, 2503, 4970, 28875, 353, 1454, 341, 853, 609, 1454, 515, 197, 23847, 25, 1758, 66760, 345, 197, 20883, 25, 1758, 15753, 38837, 197, 21169, 25, 262, 8879, 17305, 20698, 11, 2827, 1112, 1326, 197, 532, 630, 322, 4600, 2522, 11450, 458, 4287, 1943, 1465, 448, 2639, 274, 624, 2830, 4600, 2522, 1141, 526, 8, 353, 1454, 341, 853, 4600, 69, 80821, 2522, 1141, 340, 630, 322, 8104, 7289, 264, 2526, 2639, 389, 384, 624, 2830, 320, 68, 353, 1454, 8, 8104, 1141, 526, 8, 353, 1454, 341, 7727, 4299, 284, 274, 198, 853, 384, 198, 630, 322, 12104, 11367, 264, 2526, 4247, 311, 384, 624, 2830, 320, 68, 353, 1454, 8, 12104, 5969, 11, 348, 914, 8, 353, 1454, 341, 7727, 17093, 1904, 5969, 11, 348, 340, 853, 384, 198, 630, 9612, 300, 10758, 17286, 14, 29870, 4698, 81, 3366, 18008, 4130, 14, 29172, 9199, 37255, 18002, 69233, 20, 15, 8125, 24, 17, 198, 322, 32149, 96059, 374, 264, 22188, 1618, 892, 8473, 274, 624, 2830, 320, 82, 353, 5475, 8, 32149, 96059, 368, 1465, 341, 6725, 70419, 445, 24617, 19424, 2812, 3538, 389, 1018, 82, 497, 274, 5423, 64091, 340, 853, 11446, 83535, 1141, 5423, 64091, 11, 274, 31010, 2398, 630, 1313, 16403, 2582, 2036, 341, 197, 44814, 3056, 917, 1565, 2236, 2974, 81907, 8805, 630, 322, 16403, 3050, 13469, 16403, 1681, 624, 322, 3703, 1110, 14120, 91131, 905, 14, 29172, 45389, 10508, 26559, 25584, 369, 803, 5785, 624, 2830, 320, 82, 353, 5475, 8, 16403, 3050, 3622, 1758, 37508, 11, 435, 353, 1254, 9659, 8, 1465, 341, 8810, 2353, 48, 1669, 330, 77, 698, 40668, 48, 1669, 330, 4259, 1837, 197, 322, 7854, 1681, 369, 1759, 2354, 44265, 624, 2405, 4051, 4772, 2972, 5814, 5632, 198, 10676, 1669, 435, 20893, 198, 18534, 1669, 575, 15685, 741, 2023, 595, 11, 348, 1669, 2088, 2804, 341, 197, 743, 2422, 3747, 8, 961, 220, 16, 341, 298, 853, 7013, 13080, 1006, 571, 197, 1, 11808, 3239, 1018, 82, 7533, 82, 497, 595, 11, 348, 568, 2522, 19886, 69497, 340, 197, 197, 532, 197, 8961, 595, 341, 197, 2722, 3930, 48, 510, 298, 8810, 2353, 2507, 11, 1848, 1669, 33317, 67107, 3747, 58, 15, 2546, 298, 743, 1848, 961, 2092, 341, 571, 853, 7013, 13080, 1006, 464, 197, 1, 11808, 3930, 1018, 82, 25, 1018, 82, 497, 348, 11, 1848, 568, 2522, 19886, 69497, 340, 298, 197, 532, 298, 743, 3930, 2507, 621, 220, 15, 341, 571, 853, 7013, 13080, 1006, 464, 197, 1, 11808, 3930, 1018, 67, 497, 3930, 2507, 568, 2522, 19886, 69497, 340, 298, 197, 532, 298, 50108, 1214, 2353, 284, 3930, 2507, 198, 197, 2722, 4347, 48, 510, 298, 50108, 61958, 284, 348, 58, 15, 921, 197, 11940, 510, 298, 853, 7013, 13080, 445, 11808, 3239, 1018, 82, 497, 595, 568, 2522, 19886, 69497, 340, 197, 197, 532, 197, 3733, 151666, 198, 151645, 198, 151644, 872, 271, 7039, 11, 3118, 389, 279, 3881, 2266, 1995, 11, 2550, 279, 1102, 5961, 438, 2567, 11, 25470, 2701, 279, 2550, 3561, 8502, 2041, 894, 5107, 2550, 13, 9258, 1969, 1191, 448, 366, 76303, 38283, 29, 323, 835, 448, 690, 76303, 38283, 14276, 3155, 537, 1618, 894, 5392, 624, 334, 98743, 1019, 12, 3155, 537, 2550, 2197, 429, 525, 279, 1852, 438, 279, 1482, 2197, 624, 12, 1416, 279, 2390, 374, 1602, 4285, 11, 279, 2197, 5944, 1265, 387, 438, 4285, 438, 3204, 624, 12, 1446, 27732, 5889, 304, 8453, 7, 104811, 8, 151645, 198, 151644, 77091, 198, 27, 76303, 38283, 397, 515, 220, 330, 3615, 788, 2278, 262, 341, 414, 330, 2102, 788, 330, 28648, 12, 29172, 23904, 756, 414, 330, 606, 788, 330, 35, 13659, 32112, 5333, 756, 414, 330, 37555, 2458, 788, 4383, 8092, 14, 8092, 4030, 37255, 18002, 497, 330, 22803, 14, 29172, 9199, 37255, 18002, 497, 330, 2740, 61764, 29172, 14, 42315, 82, 18002, 497, 330, 2740, 61764, 29172, 3470, 68164, 18002, 497, 330, 2740, 61764, 29172, 29581, 18002, 8097, 414, 330, 40581, 788, 330, 50377, 40549, 32112, 5333, 43589, 100700, 111116, 1773, 65577, 16447, 3366, 69162, 40549, 32112, 10130, 5333, 348, 17, 43589, 100873, 114288, 33071, 101884, 3837, 100630, 100811, 65101, 83751, 72225, 5373, 105713, 39352, 5373, 35112, 6567, 36548, 105679, 10130, 10236, 104, 107, 27442, 1773, 66833, 41932, 608, 85, 17, 14, 51461, 117, 78882, 27442, 5373, 14, 85, 17, 19632, 26539, 74866, 106, 23656, 78882, 27442, 5373, 14, 85, 17, 9388, 606, 4472, 42315, 82, 9388, 16291, 92, 97259, 227, 23990, 78882, 27442, 5373, 14, 85, 17, 9388, 606, 4472, 38145, 1279, 9388, 36339, 92, 23404, 10236, 104, 107, 27442, 5373, 14, 85, 17, 9388, 606, 4472, 38145, 1279, 29581, 14, 220, 52526, 78882, 27442, 9370, 101884, 1773, 66394, 103991, 78882, 27442, 9370, 10130, 81454, 9909, 3806, 5373, 6221, 5373, 2946, 5373, 14424, 5373, 35433, 64359, 3144, 6567, 44401, 28330, 5373, 34859, 14, 102808, 68805, 5373, 44091, 16476, 33108, 32100, 54542, 1773, 100630, 104510, 100674, 5373, 43815, 31905, 105493, 5373, 17177, 99922, 52526, 5373, 63789, 27442, 99448, 41683, 49567, 104112, 105539, 1773, 99553, 104621, 37029, 19793, 26355, 3837, 101987, 100007, 67338, 100142, 40549, 41479, 95, 17523, 78882, 57218, 16447, 3366, 220, 108221, 63276, 414, 330, 5864, 26564, 788, 330, 35, 13659, 32112, 5333, 54851, 105168, 9370, 107736, 101931, 3837, 43815, 101162, 102024, 100136, 100166, 104542, 3837, 102298, 101213, 108247, 78882, 27442, 101884, 3837, 77288, 71268, 106466, 103967, 9370, 25414, 18137, 96, 236, 33983, 3837, 104964, 46944, 111116, 15946, 100873, 53481, 55338, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 414, 330, 4648, 31206, 788, 330, 8996, 698, 262, 1153, 262, 341, 414, 330, 2102, 788, 330, 2245, 41387, 71312, 23904, 756, 414, 330, 606, 788, 330, 110195, 17881, 10130, 5333, 756, 414, 330, 37555, 2458, 788, 4383, 8092, 14, 8092, 4030, 37255, 18002, 497, 330, 8611, 34827, 4030, 37255, 18002, 497, 330, 50395, 14, 13131, 388, 2836, 37255, 18002, 497, 330, 5834, 21492, 76196, 4030, 37255, 18002, 497, 330, 22803, 18008, 4130, 4030, 37255, 18002, 8097, 414, 330, 40581, 788, 330, 50377, 16447, 3366, 38433, 226, 110195, 17881, 10130, 5333, 43589, 100873, 111116, 1773, 100700, 65577, 20713, 5373, 13298, 5373, 31133, 5373, 16219, 5373, 11066, 12, 1552, 44054, 93437, 104186, 101979, 104516, 107736, 1773, 100630, 20713, 43589, 608, 12120, 5373, 14, 65512, 5373, 14, 12885, 10236, 104, 107, 27442, 24968, 13298, 43589, 608, 12120, 5373, 14, 38145, 1279, 5373, 14, 6120, 10236, 104, 107, 27442, 24968, 31133, 43589, 608, 65512, 5373, 14, 2388, 19842, 5373, 14, 4059, 466, 824, 10236, 104, 107, 27442, 24968, 11066, 12, 1552, 43589, 608, 14082, 5373, 14, 81907, 5373, 14, 1826, 6295, 10236, 104, 107, 27442, 24968, 16219, 43589, 608, 29172, 14, 38188, 5373, 14, 22803, 5457, 16, 14, 29172, 4322, 1097, 2995, 10236, 104, 107, 27442, 1773, 32664, 103991, 5333, 220, 66833, 41932, 10130, 81454, 5373, 3144, 8908, 115, 107, 66569, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 16476, 33108, 37029, 102122, 1773, 66394, 110195, 106588, 47872, 11622, 100145, 5373, 20074, 114442, 33108, 32100, 54542, 100674, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 33108, 102705, 105866, 63276, 414, 330, 5864, 26564, 788, 330, 110195, 17881, 10130, 5333, 6567, 114, 231, 81217, 101213, 101970, 110195, 3837, 103991, 110195, 118755, 105071, 5333, 10236, 104, 107, 27442, 33108, 104559, 3837, 85106, 59879, 110195, 71817, 101348, 17177, 5122, 16810, 5333, 5373, 13298, 5333, 5373, 31133, 5333, 5373, 11066, 12, 1552, 5333, 5373, 16219, 5333, 3837, 105920, 100700, 66394, 103991, 110195, 106708, 107736, 63276, 414, 330, 4648, 31206, 788, 330, 14004, 756, 414, 330, 5864, 788, 2278, 286, 341, 688, 330, 2102, 788, 330, 8092, 71312, 23904, 756, 688, 330, 606, 788, 330, 16810, 10130, 5333, 756, 688, 330, 37555, 2458, 788, 4383, 8092, 14, 8092, 4030, 37255, 18002, 497, 330, 8092, 14, 8092, 2972, 25085, 18002, 8097, 688, 330, 40581, 788, 330, 100700, 65577, 20713, 44054, 93437, 9370, 10130, 5333, 46602, 98, 39426, 1773, 100630, 99722, 101071, 78882, 27442, 608, 12120, 5373, 47, 17, 47, 41479, 96, 51827, 78882, 27442, 608, 65512, 5373, 20074, 62189, 78882, 27442, 608, 12885, 5373, 44091, 51154, 78882, 27442, 49567, 1773, 66394, 103991, 78882, 27442, 106708, 101884, 3837, 100630, 10130, 81454, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 54542, 1773, 66833, 41932, 20713, 69372, 98749, 67338, 100001, 5333, 220, 106961, 110195, 108221, 3837, 100630, 69041, 40179, 62579, 49026, 5373, 45181, 17116, 58143, 92894, 20713, 39095, 27366, 20074, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 3837, 101987, 104621, 100007, 57218, 20713, 220, 108221, 1773, 65577, 20713, 64388, 117, 99996, 32665, 3837, 29524, 14397, 3034, 5373, 78882, 39426, 85767, 5373, 62189, 104238, 49567, 63276, 688, 330, 5864, 26564, 788, 330, 16810, 10130, 5333, 54851, 106215, 110195, 9370, 107736, 3837, 98380, 101162, 101096, 3837, 99558, 102074, 393, 17, 47, 39095, 27366, 33108, 99722, 101071, 3837, 43815, 16530, 102181, 3837, 46944, 111116, 106131, 102994, 55338, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 688, 330, 4648, 31206, 788, 330, 8996, 698, 286, 1153, 286, 341, 688, 330, 2102, 788, 330, 8611, 71312, 23904, 756, 688, 330, 606, 788, 330, 13298, 10130, 5333, 756, 688, 330, 37555, 2458, 788, 4383, 8611, 34827, 4030, 37255, 18002, 497, 330, 8611, 34827, 2972, 25085, 18002, 8097, 688, 330, 40581, 788, 330, 100700, 65577, 17116, 44054, 93437, 9370, 10130, 5333, 46602, 98, 39426, 1773, 100630, 99722, 101071, 78882, 27442, 608, 12120, 5373, 35112, 53497, 246, 99871, 78882, 27442, 608, 38145, 1279, 9388, 36339, 92, 5373, 52526, 54542, 78882, 27442, 608, 6120, 5373, 103414, 39352, 78882, 27442, 49567, 1773, 66394, 103991, 78882, 27442, 106708, 101884, 3837, 100630, 10130, 81454, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 54542, 1773, 66833, 41932, 17116, 69372, 98749, 100622, 105975, 92374, 99553, 20074, 47874, 3837, 100630, 23404, 53497, 246, 99871, 5373, 98671, 99658, 86312, 39352, 5373, 118878, 107101, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 3837, 101987, 100007, 52526, 33108, 62189, 23404, 62262, 1773, 65577, 17116, 64388, 117, 99996, 32665, 3837, 29524, 105653, 33447, 78882, 85767, 5373, 103414, 100786, 102498, 5373, 105173, 104238, 49567, 63276, 688, 330, 5864, 26564, 788, 330, 13298, 10130, 5333, 54851, 106215, 110195, 9370, 107736, 3837, 99558, 102074, 23404, 53497, 246, 99871, 33108, 103414, 39352, 3837, 98380, 101162, 102024, 3837, 46944, 111116, 73670, 100873, 53481, 55338, 78556, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 688, 330, 4648, 31206, 788, 330, 8996, 698, 286, 1153, 286, 341, 688, 330, 2102, 788, 330, 50395, 71312, 23904, 756, 688, 330, 606, 788, 330, 31133, 10130, 5333, 756, 688, 330, 37555, 2458, 788, 4383, 50395, 14, 13131, 388, 2836, 37255, 18002, 497, 330, 50395, 14, 65512, 2972, 25085, 18002, 497, 330, 50395, 90228, 466, 69, 509, 1451, 25085, 18002, 8097, 688, 330, 40581, 788, 330, 100700, 65577, 40179, 44054, 93437, 9370, 10130, 5333, 46602, 98, 39426, 1773, 100630, 99722, 101071, 78882, 27442, 608, 12120, 5373, 47, 17, 47, 41479, 96, 51827, 78882, 27442, 608, 65512, 5373, 32664, 49567, 92374, 99879, 78882, 27442, 608, 2388, 19842, 5373, 23305, 110042, 78882, 27442, 608, 4059, 466, 824, 5373, 100787, 27369, 78882, 27442, 49567, 1773, 66394, 103991, 78882, 27442, 106708, 101884, 3837, 100630, 10130, 81454, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 54542, 1773, 66833, 41932, 40179, 69372, 98749, 102020, 393, 17, 47, 10236, 121, 239, 68065, 3837, 100630, 32664, 49567, 92374, 39352, 5373, 64064, 17177, 28291, 5373, 71356, 100786, 102498, 101999, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 3837, 101987, 20713, 69372, 98749, 67338, 40179, 69425, 46451, 92894, 32664, 49567, 92374, 1773, 65577, 40179, 64388, 117, 99996, 32665, 3837, 29524, 32664, 49567, 92374, 104238, 5373, 71356, 85767, 5373, 102111, 104118, 49567, 63276, 688, 330, 5864, 26564, 788, 330, 31133, 10130, 5333, 54851, 106215, 110195, 9370, 107736, 3837, 99558, 102074, 393, 17, 47, 66521, 237, 47872, 33108, 32664, 49567, 92374, 39352, 3837, 98380, 101162, 101096, 3837, 46944, 111116, 73670, 102994, 55338, 78556, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 688, 330, 4648, 31206, 788, 330, 8996, 698, 286, 1153, 286, 341, 688, 330, 2102, 788, 330, 5834, 21492, 71312, 23904, 756, 688, 330, 606, 788, 330, 11066, 12, 1552, 10130, 5333, 756, 688, 330, 37555, 2458, 788, 4383, 5834, 21492, 76196, 4030, 37255, 18002, 497, 330, 5834, 21492, 76196, 2972, 25085, 18002, 8097, 688, 330, 40581, 788, 330, 100700, 65577, 7854, 12, 1552, 44054, 93437, 9370, 10130, 5333, 46602, 98, 39426, 1773, 100630, 99722, 101071, 78882, 27442, 608, 12120, 5373, 80158, 100829, 101071, 78882, 27442, 608, 878, 1880, 5373, 105151, 39352, 78882, 27442, 608, 14082, 9388, 4578, 92, 5373, 106871, 44177, 78882, 27442, 608, 81907, 9388, 23476, 4472, 14082, 5373, 105710, 105173, 78882, 27442, 608, 1826, 6295, 84460, 9388, 4578, 92, 5373, 13298, 80528, 78882, 27442, 608, 8611, 10236, 255, 231, 1773, 66394, 103991, 78882, 27442, 106708, 101884, 3837, 100630, 10130, 81454, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 54542, 1773, 66833, 41932, 7854, 12, 1552, 69372, 98749, 39352, 105151, 26939, 106208, 9370, 100261, 99759, 3837, 100630, 99960, 103414, 105173, 5373, 105537, 106637, 5373, 105151, 104738, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 3837, 101987, 100007, 51154, 33108, 50007, 105151, 27369, 1773, 65577, 7854, 12, 1552, 64388, 117, 99996, 32665, 3837, 29524, 105173, 104190, 5373, 105537, 104238, 5373, 105653, 85767, 49567, 63276, 688, 330, 5864, 26564, 788, 330, 11066, 12, 1552, 10130, 5333, 54851, 106215, 110195, 9370, 107736, 3837, 99558, 102074, 105151, 108069, 99960, 103414, 105173, 3837, 98380, 101162, 102024, 3837, 46944, 111116, 73670, 100873, 53481, 55338, 78556, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 688, 330, 4648, 31206, 788, 330, 8996, 698, 286, 1153, 286, 341, 688, 330, 2102, 788, 330, 22803, 71312, 23904, 756, 688, 330, 606, 788, 330, 16219, 10130, 5333, 756, 688, 330, 37555, 2458, 788, 4383, 22803, 18008, 4130, 4030, 37255, 18002, 497, 330, 22803, 14, 29172, 9199, 37255, 18002, 8097, 688, 330, 40581, 788, 330, 100700, 65577, 32778, 44054, 93437, 9370, 10130, 5333, 46602, 98, 39426, 1773, 100630, 99722, 101071, 78882, 27442, 608, 12120, 5373, 61689, 20742, 101151, 78882, 27442, 608, 29172, 14, 38188, 5373, 98841, 18158, 78882, 27442, 608, 22803, 5457, 16, 14, 29172, 4322, 1097, 2995, 5373, 106130, 100795, 78882, 27442, 608, 85, 17, 19632, 26539, 10236, 255, 231, 1773, 66394, 103991, 78882, 27442, 106708, 101884, 3837, 100630, 10130, 81454, 5373, 34859, 32665, 5373, 102808, 68805, 5373, 32100, 54542, 1773, 66833, 41932, 32778, 69372, 98749, 54542, 100811, 65101, 52526, 33108, 98841, 99259, 3837, 100630, 57621, 54542, 5373, 98841, 18158, 104238, 5373, 116817, 104190, 1773, 99553, 5333, 8908, 108, 225, 11622, 19793, 26355, 3837, 101987, 100007, 85767, 33108, 37029, 32778, 220, 47874, 1773, 65577, 32778, 64388, 117, 99996, 32665, 3837, 29524, 52526, 104238, 5373, 98841, 99259, 85767, 5373, 57621, 108459, 49567, 63276, 688, 330, 5864, 26564, 788, 330, 16219, 10130, 5333, 54851, 106215, 110195, 9370, 107736, 3837, 99558, 102074, 100811, 65101, 52526, 101259, 33108, 98841, 99259, 98380, 3837, 98380, 101162, 101096, 3837, 46944, 111116, 73670, 102994, 55338, 78556, 78882, 27442, 3837, 106431, 100642, 101348, 17177, 63276, 688, 330, 4648, 31206, 788, 330, 8996, 698, 286, 456, 414, 5133, 262, 1153, 262, 341, 414, 330, 2102, 788, 330, 79, 17, 79, 12, 17014, 23904, 756, 414, 330, 606, 788, 330, 47, 17, 47, 66521, 237, 96422, 107736, 756, 414, 330, 37555, 2458, 788, 4383, 15110, 4322, 17, 79, 4322, 17, 79, 57322, 497, 330, 2740, 5523, 48709, 2687, 15222, 14, 5148, 14, 5148, 18002, 497, 330, 2740, 5523, 48709, 2687, 15222, 14, 5148, 7530, 437, 927, 4407, 18002, 497, 330, 2740, 5523, 48709, 2687, 15222, 14, 5148, 50624, 18002, 8097, 414, 330, 40581, 788, 330, 50377, 393, 17, 47, 66521, 237, 96422, 107736, 9370, 100700, 111116, 1773, 104210, 18433, 4322, 17, 79, 4322, 17, 79, 57322, 41479, 248, 64559, 3837, 66833, 41932, 16447, 3366, 393, 17, 47, 10236, 121, 239, 68065, 105778, 68805, 5373, 64064, 101136, 5373, 118376, 102054, 33108, 20074, 107468, 100674, 1773, 100630, 101136, 118098, 23836, 64205, 91282, 5373, 40820, 41299, 43316, 20074, 116226, 68805, 5373, 64064, 100641, 33108, 101999, 102054, 1773, 66394, 8536, 29661, 5373, 8344, 2566, 5373, 12116, 5373, 1900, 5373, 31209, 5373, 9269, 10236, 255, 231, 64205, 109963, 100166, 33108, 105795, 1773, 66833, 41932, 32664, 49567, 92374, 106588, 104516, 101136, 3837, 100630, 64064, 44091, 39352, 5373, 20074, 34859, 104238, 5373, 104242, 100359, 100674, 1773, 99553, 101136, 101884, 19793, 26355, 3837, 101987, 100007, 100641, 393, 17, 47, 32181, 252, 29077, 33108, 107468, 20074, 1773, 65577, 101136, 71109, 39352, 5373, 114288, 33071, 101882, 33108, 99464, 101118, 63276, 414, 330, 5864, 26564, 788, 330, 47, 17, 47, 66521, 237, 96422, 107736, 20412, 104210, 69634, 41479, 248, 64559, 9370, 40820, 41299, 43316, 101136, 3837, 43815, 101162, 102024, 100136, 100166, 32108, 3837, 99558, 102074, 64205, 68805, 33108, 104516, 102054, 3837, 46944, 111116, 73670, 100873, 53481, 101136, 101931, 3837, 106431, 100642, 101348, 17177, 63276, 414, 330, 4648, 31206, 788, 330, 8996, 698, 262, 1153, 262, 341, 414, 330, 2102, 788, 330, 12120, 54785, 23904, 756, 414, 330, 606, 788, 330, 99722, 101071, 57218, 39352, 5333, 756, 414, 330, 37555, 2458, 788, 4383, 2740, 14, 12120, 2028, 14, 32225, 18002, 497, 330, 2740, 14, 12120, 2028, 46619, 261, 18002, 497, 330, 6031, 7530, 5252, 7530, 5252, 18002, 497, 330, 2740, 3183, 11603, 3183, 11603, 18002, 8097, 414, 330, 40581, 788, 330, 50377, 99722, 101071, 57218, 39352, 5333, 43589, 100700, 111116, 1773, 65577, 55338, 110195, 9370, 99722, 101071, 78882, 27442, 3837, 100630, 105266, 33071, 101071, 608, 12120, 5373, 80158, 100829, 33071, 101071, 608, 878, 1880, 5373, 104118, 107047, 107736, 608, 43262, 5373, 110760, 78882, 27442, 608, 8349, 87146, 299, 69, 10236, 255, 231, 1773, 66394, 99722, 101071, 9370, 101884, 100674, 3837, 100630, 101966, 101071, 5373, 107285, 101071, 5373, 108459, 31548, 33108, 104814, 104238, 1773, 66833, 41932, 39352, 5333, 43589, 98380, 3837, 100630, 104299, 85767, 50007, 5373, 44091, 51154, 5373, 113308, 40090, 107736, 1773, 99553, 104814, 102705, 105866, 3837, 101987, 100007, 57218, 97137, 5373, 44, 18, 10236, 255, 231, 104814, 72448, 102705, 1773, 65577, 32100, 54542, 5373, 71304, 13343, 85767, 5373, 29258, 41321, 104238, 49567, 113308, 78556, 32665, 1773, 100630, 105716, 105853, 33108, 102111, 47872, 90172, 9370, 5333, 85658, 39907, 63276, 414, 330, 5864, 26564, 788, 330, 99722, 101071, 57218, 39352, 5333, 6567, 114, 113, 99758, 34187, 105743, 104814, 5373, 113308, 33108, 39352, 98380, 3837, 43815, 101162, 103967, 3837, 99558, 102074, 99722, 44091, 33108, 72448, 39352, 3837, 46944, 111116, 73670, 100873, 53481, 55338, 78556, 107736, 3837, 106431, 100642, 101348, 17177, 63276, 414, 330, 4648, 31206, 788, 330, 8996, 698, 262, 456, 220, 5133, 532, 522, 76303, 38283, 29, 151645, 198], "loss_mask": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]} diff --git a/progress/github/SpecForge/tests/test_layers/__init__.py b/progress/github/SpecForge/tests/test_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_layers/test_decoder.py b/progress/github/SpecForge/tests/test_layers/test_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..db6d785db9291e7d981f1c5f67424067cdf0d27d --- /dev/null +++ b/progress/github/SpecForge/tests/test_layers/test_decoder.py @@ -0,0 +1,395 @@ +import os +import time +import unittest + +import torch +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from torch import nn +from transformers import PretrainedConfig +from yunchang import EXTRACT_FUNC_DICT + +from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter +from specforge.data.preprocessing import build_offline_eagle3_dataset + +# Project-specific imports +from specforge.distributed import destroy_distributed, init_distributed +from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer +from specforge.utils import padding +from tests.utils import get_available_port + + +def get_model_config(): + """Create and return the model configuration.""" + config_dict = { + "architectures": ["LlamaForCausalLMEagle3"], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [1, 29, 57], + "use_aux_hidden_state": True, + }, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 29568, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": True, + "rope_scaling": None, + "vocab_size": 129280, + "draft_vocab_size": 32000, + "pretraining_tp": 1, + } + return PretrainedConfig.from_dict(config_dict) + + +def setup_env(rank, world_size, port): + """Set up distributed environment variables.""" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + torch.cuda.set_device(rank) + + +def dbg(rank, msg): + print(f"[rank{rank}] {msg}", flush=True) + + +def wait_for_file(path, timeout_s=60, poll_s=0.1): + start = time.time() + while time.time() - start < timeout_s: + if os.path.exists(path): + return True + time.sleep(poll_s) + return False + + +def run_iterative_pass( + decoder_layer, + embed_tokens, + input_ids, + hidden_states, + attention_mask, + position_ids, + ttt_length, +): + """ + Core loop: execute the forward pass `ttt_length` times. + Used for both Golden (SDPA) and Distributed (USP) runs to ensure logic consistency. + """ + # Clone to avoid side effects on original tensors + curr_input_ids = input_ids.clone() + curr_hidden_states = hidden_states.clone() + + # Init cache + cache_hidden = [[], []] + past_key_values = None + final_output = None + + for idx in range(ttt_length): + is_last = idx == ttt_length - 1 + + # 1. Embed inputs + inputs_embeds = embed_tokens(curr_input_ids).to(curr_hidden_states.dtype) + + # 2. Forward pass + output_hidden_states = decoder_layer( + input_emb=inputs_embeds, + hidden_states=curr_hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=False, + use_cache=False, + ) + + # Update states for next iteration + curr_hidden_states = output_hidden_states + final_output = output_hidden_states + + # 3. Simulate TTT padding/shift + if not is_last: + curr_input_ids = padding(curr_input_ids, left=False) + + return final_output + + +def run_test_case(rank, world_size, port): + """Worker function executed in each process.""" + setup_env(rank, world_size, port) + device = torch.device(f"cuda:{rank}") + set_seed(42) + dbg(rank, "env setup complete") + + # --- Data & Config Preparation --- + config = get_model_config() + seq_len = 1560 + batch_size = 1 + ttt_length = 3 + + # Generate dummy data on GPU + data_input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device) + data_hidden_states = torch.randn( + batch_size, seq_len, config.hidden_size, device=device, dtype=torch.bfloat16 + ) + attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view( + 1, 1, seq_len, seq_len + ) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + + # Shared embedding layer + embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, config.pad_token_id + ).to(device) + + # --- Phase 1: Golden Run (FA) --- + # Init dist briefly for internal checks, even if running single-device logic + init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1) + dbg(rank, "init_distributed (FA) done") + + sdpa_decoder = ( + LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16) + ) + dbg(rank, "FA decoder created") + # Adapter smoke test for FA/SDPA-style path + dummy_model = type("Dummy", (), {})() + sdpa_adapter = SdpaLikeAdapter(dummy_model) + sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32) + sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32) + sdpa_state = sdpa_adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=data_input_ids, + attention_mask=attention_mask, + loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32), + position_ids=position_ids, + hidden_states=data_hidden_states, + target_p_padded=sdpa_target_p, + position_mask=sdpa_position_mask, + seq_length=seq_len, + ) + assert sdpa_state.input_ids.shape[1] == seq_len + assert sdpa_state.hidden_states.shape[1] == seq_len + + with torch.no_grad(): + sdpa_output = run_iterative_pass( + decoder_layer=sdpa_decoder, + embed_tokens=embed_tokens, + input_ids=data_input_ids, + hidden_states=data_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ttt_length=ttt_length, + ) + dbg(rank, "FA forward done") + + # Save weights for alignment and cleanup SDPA model + state_dict = sdpa_decoder.state_dict() + del sdpa_decoder + destroy_distributed() + dbg(rank, "destroy_distributed (FA) done") + + # --- Phase 2: Distributed Run (USP) --- + def subtest_usp(sp_ulysses_degree, sp_ring_degree): + """Run USP with specific topology and compare against Golden.""" + try: + init_distributed( + tp_size=1, + sp_ulysses_size=sp_ulysses_degree, + sp_ring_size=sp_ring_degree, + ) + dbg( + rank, + f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done", + ) + # Dataset + adapter smoke test (USP path) + tmp_dir = "./tmp/usp_dataset_shared" + try: + if rank == 0: + os.makedirs(tmp_dir, exist_ok=True) + sample = { + "input_ids": data_input_ids[0].cpu(), + "loss_mask": torch.ones_like(data_input_ids[0].cpu()), + "hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + "aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + } + torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt")) + dbg(rank, "wrote sample ckpt") + ready_flag = os.path.join(tmp_dir, "ready.flag") + with open(ready_flag, "w", encoding="utf-8") as f: + f.write("ready\n") + if rank != 0: + ready_flag = os.path.join(tmp_dir, "ready.flag") + assert wait_for_file( + ready_flag, timeout_s=60 + ), "timeout waiting for ready flag" + dbg(rank, "dataset sync done") + assert os.path.exists( + os.path.join(tmp_dir, "data_0.ckpt") + ), f"Expected sample not found at {tmp_dir}" + dbg(rank, "sample exists") + + ds = build_offline_eagle3_dataset( + tmp_dir, + max_len=seq_len, + ttt_length=ttt_length, + use_usp_preprocess=True, + ) + dbg(rank, "dataset built") + item = ds[0] + dbg(rank, "dataset item loaded") + assert "position_ids" in item + + dummy_model = type("Dummy", (), {})() + adapter = UspAdapter(dummy_model) + local_seq_len = item["input_ids"].shape[1] + target_p_padded = torch.zeros( + (1, local_seq_len, 8), device=device, dtype=torch.float32 + ) + position_mask = torch.ones( + (1, local_seq_len, 1), device=device, dtype=torch.float32 + ) + state = adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=item["input_ids"].to(device), + attention_mask=item["attention_mask"].to(device), + loss_mask=item["loss_mask"].to(device).unsqueeze(-1), + position_ids=item["position_ids"].to(device), + hidden_states=item["hidden_state"].to(device), + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=local_seq_len, + ) + assert state.input_ids.shape[1] == local_seq_len - ttt_length + assert state.hidden_states.shape[1] == local_seq_len - ttt_length + dbg(rank, "adapter step_view ok") + finally: + if rank == 0: + done_flag = os.path.join(tmp_dir, "done.flag") + assert wait_for_file( + done_flag, timeout_s=60 + ), "timeout waiting for done flag" + try: + for root, _, files in os.walk(tmp_dir): + for name in files: + os.remove(os.path.join(root, name)) + os.rmdir(tmp_dir) + except OSError: + pass + else: + done_flag = os.path.join(tmp_dir, "done.flag") + with open(done_flag, "w", encoding="utf-8") as f: + f.write("done\n") + + # Init USP model and load golden weights + usp_decoder = ( + LlamaDecoderLayer(config, attention_backend="usp") + .to(device) + .to(torch.bfloat16) + ) + usp_decoder.load_state_dict(state_dict) + dbg(rank, "USP decoder loaded") + + # Shard data (Split Input) + extract_func = EXTRACT_FUNC_DICT["basic"] + + local_input_ids = ( + extract_func( + data_input_ids, + rank, + world_size=world_size, + rd=sp_ring_degree, + ud=sp_ulysses_degree, + ) + .detach() + .clone() + ) + + local_hidden_states = ( + extract_func( + data_hidden_states, + rank, + world_size=world_size, + rd=sp_ring_degree, + ud=sp_ulysses_degree, + ) + .detach() + .clone() + ) + dbg(rank, "USP local inputs prepared") + total_degree = sp_ring_degree * sp_ulysses_degree + chunk_size = sdpa_output.shape[1] // total_degree + start_idx = (rank % total_degree) * chunk_size + local_len = local_input_ids.shape[1] + local_position_ids = ( + torch.arange(start_idx, start_idx + local_len, device=device) + .unsqueeze(0) + .long() + ) + local_attention_mask = torch.tril( + torch.ones(local_len, local_len, device=device) + ).view(1, 1, local_len, local_len) + + # Run USP forward + if sp_ring_degree > 1: + usp_attention_mask = local_attention_mask + usp_position_ids = local_position_ids + else: + usp_attention_mask = attention_mask + usp_position_ids = position_ids + with torch.no_grad(): + usp_output = run_iterative_pass( + decoder_layer=usp_decoder, + embed_tokens=embed_tokens, + input_ids=local_input_ids, + hidden_states=local_hidden_states, + attention_mask=usp_attention_mask, + position_ids=usp_position_ids, + ttt_length=ttt_length, + ) + dbg(rank, "USP forward done") + + # Verify results + # Slice the golden output to match the current rank's chunk + end_idx = start_idx + chunk_size + + golden_chunk = sdpa_output[:, start_idx:end_idx, :] + + assert torch.allclose(usp_output, golden_chunk, rtol=2e-2, atol=2e-2), ( + f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n" + f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}" + ) + dbg(rank, "USP output verified") + + finally: + destroy_distributed() + dbg(rank, "destroy_distributed (USP) done") + + # Case 1: Hybrid (Ulysses=2, Ring=1) + subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1) + + # Case 2: Hybrid (Ulysses=1, Ring=2) + subtest_usp(sp_ulysses_degree=1, sp_ring_degree=2) + + +class TestTTTDistributed(unittest.TestCase): + def test_llama_usp_decoder(self): + world_size = 2 + port = get_available_port() + mp.spawn(run_test_case, nprocs=world_size, args=(world_size, port)) + + +if __name__ == "__main__": + unittest.main() diff --git a/progress/github/SpecForge/tests/test_layers/test_embedding.py b/progress/github/SpecForge/tests/test_layers/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..fc16cbcbbc95cda35d68be50797e12c5b3a4c9de --- /dev/null +++ b/progress/github/SpecForge/tests/test_layers/test_embedding.py @@ -0,0 +1,75 @@ +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed + +from specforge.distributed import init_distributed +from specforge.layers import VocabParallelEmbedding +from tests.utils import get_available_port + + +def run_embedding(rank, world_size, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + init_distributed(tp_size=world_size) + set_seed(42) + + # =============================== + # Case 1: vocab size is divisible by the TP size + # =============================== + # create layers + data = torch.randint(0, 512, (1, 128)).long().cuda() + native_embedding = torch.nn.Embedding(512, 256, padding_idx=314).cuda() + sf_embedding = VocabParallelEmbedding(512, 256, padding_idx=314).cuda() + sf_embedding.load_state_dict(native_embedding.state_dict()) + + # forward + native_output = native_embedding(data) + sf_output = sf_embedding(data) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + # =============================== + # Case 2: vocab size is NOT divisible by the TP size + # =============================== + # create layers + data = torch.randint(0, 355, (1, 128)).long().cuda() + native_embedding = torch.nn.Embedding(355, 256, padding_idx=314).cuda() + sf_embedding = VocabParallelEmbedding(355, 256, padding_idx=314).cuda() + sf_embedding.load_state_dict(native_embedding.state_dict()) + + # forward + native_output = native_embedding(data) + sf_output = sf_embedding(data) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + dist.destroy_process_group() + + +class TestEmbedding(unittest.TestCase): + + def test_embedding(self): + port = get_available_port() + mp.spawn(run_embedding, nprocs=2, args=(2, port)) + + port = get_available_port() + mp.spawn(run_embedding, nprocs=1, args=(1, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestEmbedding)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_layers/test_linear.py b/progress/github/SpecForge/tests/test_layers/test_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec5706958c6a03fe341f142f9139084583a1b66 --- /dev/null +++ b/progress/github/SpecForge/tests/test_layers/test_linear.py @@ -0,0 +1,147 @@ +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed + +from specforge.distributed import gather_tensor, get_tp_group, init_distributed +from specforge.layers import ColumnParallelLinear, RowParallelLinear +from tests.utils import get_available_port + + +def run_column_parallel_linear(rank, world_size, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + init_distributed(tp_size=world_size) + set_seed(42) + + # =============================== + # Case 1: normal layout + # =============================== + # create data + data = torch.rand(1, 256).cuda() + + # create layers + native_linear = torch.nn.Linear(256, 512).cuda() + sf_linear = ColumnParallelLinear(256, 512, layout_type="normal").cuda() + sf_linear.load_state_dict(native_linear.state_dict()) + + # forward + native_output = native_linear(data) + sf_output = sf_linear(data) + full_sf_output = gather_tensor(sf_output, get_tp_group()) + + # check + assert torch.allclose( + native_output, full_sf_output, rtol=1e-5, atol=1e-5 + ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + # =============================== + # Case 2: merged QKV layout + # =============================== + # create data + data = torch.rand(1, 256 * 3).cuda() + + # create layers + native_linear = torch.nn.Linear(256 * 3, 512).cuda() + sf_linear = ColumnParallelLinear(256 * 3, 512, layout_type="merged_qkv").cuda() + sf_linear.load_state_dict(native_linear.state_dict()) + + # forward + q, k, v = native_linear(data).chunk(3, dim=1) + sf_q, sf_k, sf_v = sf_linear(data).chunk(3, dim=1) + full_sf_q = gather_tensor(sf_q, get_tp_group()) + full_sf_k = gather_tensor(sf_k, get_tp_group()) + full_sf_v = gather_tensor(sf_v, get_tp_group()) + + # check + assert torch.allclose( + q, full_sf_q, rtol=1e-5, atol=1e-5 + ), f"q: \n{q}, \nfull_sf_q: \n{full_sf_q}" + assert torch.allclose( + k, full_sf_k, rtol=1e-5, atol=1e-5 + ), f"k: \n{k}, \nfull_sf_k: \n{full_sf_k}" + assert torch.allclose( + v, full_sf_v, rtol=1e-5, atol=1e-5 + ), f"v: \n{v}, \nfull_sf_v: \n{full_sf_v}" + + # =============================== + # Case 3: gate_up layout + # =============================== + # create data + data = torch.rand(1, 256 * 2).cuda() + + # create layers + native_linear = torch.nn.Linear(256 * 2, 512).cuda() + sf_linear = ColumnParallelLinear(256 * 2, 512, layout_type="gate_up").cuda() + sf_linear.load_state_dict(native_linear.state_dict()) + + # forward + gate, up = native_linear(data).chunk(2, dim=1) + sf_gate, sf_up = sf_linear(data).chunk(2, dim=1) + full_sf_gate = gather_tensor(sf_gate, get_tp_group()) + full_sf_up = gather_tensor(sf_up, get_tp_group()) + + # check + assert torch.allclose( + gate, full_sf_gate, rtol=1e-5, atol=1e-5 + ), f"gate: \n{gate}, \nfull_sf_gate: \n{full_sf_gate}" + assert torch.allclose( + up, full_sf_up, rtol=1e-5, atol=1e-5 + ), f"up: \n{up}, \nfull_sf_up: \n{full_sf_up}" + + dist.destroy_process_group() + + +def run_row_parallel_linear(rank, world_size, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + init_distributed(tp_size=world_size) + set_seed(42) + + # =============================== + # Case 1: normal layout + # the data in an parallel input, i.g. + # Y = AllReduce(X_i * W_i) + # =============================== + # create data + data = torch.rand(1, 256).cuda() + + # create layers + native_linear = torch.nn.Linear(256, 512).cuda() + sf_linear = RowParallelLinear(256, 512, layout_type="normal").cuda() + sf_linear.load_state_dict(native_linear.state_dict()) + + # forward + native_output = native_linear(data) + sf_output = sf_linear(data.chunk(world_size, dim=0)[rank]) + dist.all_reduce(sf_output, op=dist.ReduceOp.SUM, group=get_tp_group()) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"native_output: \n{native_output}, \nfull_sf_output: \n{full_sf_output}" + + +class TestLinear(unittest.TestCase): + + def test_column_parallel_linear(self): + port = get_available_port() + mp.spawn(run_column_parallel_linear, nprocs=2, args=(2, port)) + + def test_column_parallel_linear(self): + port = get_available_port() + mp.spawn(run_column_parallel_linear, nprocs=1, args=(1, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestLinear)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_layers/test_lm_head.py b/progress/github/SpecForge/tests/test_layers/test_lm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..948ae5fdcc74108575c64a1a4393a330429dbb77 --- /dev/null +++ b/progress/github/SpecForge/tests/test_layers/test_lm_head.py @@ -0,0 +1,108 @@ +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed + +from specforge.distributed import init_distributed +from specforge.layers import ParallelLMHead, VocabParallelEmbedding +from tests.utils import get_available_port + + +def run_lm_head(rank, world_size, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + init_distributed(tp_size=world_size) + set_seed(42) + + # =============================== + # Case 1: the output vocab size is divisible by the TP size + # =============================== + # create data + data = torch.rand(1, 128, 256).cuda() + + for bias in [True, False]: + # create layers + native_lm_head = torch.nn.Linear(256, 512, bias=bias).cuda() + sf_lm_head = ParallelLMHead(256, 512, bias=bias).cuda() + sf_lm_head.load_state_dict(native_lm_head.state_dict()) + + # forward + native_output = native_lm_head(data) + sf_output = sf_lm_head(data, gather_output=True) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"bias: {bias}, native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + # =============================== + # Case 2: the output vocab size is not divisible by the TP size + # =============================== + # create data + data = torch.rand(1, 128, 256).cuda() + + # create layers + native_lm_head = torch.nn.Linear(256, 377, bias=bias).cuda() + sf_lm_head = ParallelLMHead(256, 377, bias=bias).cuda() + sf_lm_head.load_state_dict(native_lm_head.state_dict()) + + # forward + native_output = native_lm_head(data) + sf_output = sf_lm_head(data, gather_output=True) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"bias: {bias}, native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + # =============================== + # Case 3: tie word embedding + # =============================== + if not bias: + # there is no bias in the embedding layer so we skip when bias is True + # create data + data = torch.rand(128, 256).cuda() + + # create native layers + native_embedding = torch.nn.Embedding(512, 256).cuda() + native_lm_head = torch.nn.Linear(256, 512, bias=bias).cuda() + native_lm_head.weight = native_embedding.weight + + # create specforge layers + sf_embedding = VocabParallelEmbedding(512, 256).cuda() + sf_embedding.load_state_dict(native_embedding.state_dict()) + sf_lm_head = ParallelLMHead(256, 512, bias=bias).cuda() + sf_lm_head.weight = sf_embedding.weight + + # forward + native_output = native_lm_head(data) + sf_output = sf_lm_head(data, gather_output=True) + + # check + assert torch.allclose( + native_output, sf_output, rtol=1e-5, atol=1e-5 + ), f"bias: {bias}, native_output: \n{native_output}, \nsf_output: \n{sf_output}" + + dist.destroy_process_group() + + +class TestLMHead(unittest.TestCase): + + def test_lm_head(self): + port = get_available_port() + mp.spawn(run_lm_head, nprocs=2, args=(2, port)) + + port = get_available_port() + mp.spawn(run_lm_head, nprocs=1, args=(1, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestLMHead)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_draft/__init__.py b/progress/github/SpecForge/tests/test_modeling/test_draft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_modeling/test_draft/test_llama3.py b/progress/github/SpecForge/tests/test_modeling/test_draft/test_llama3.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fa86c80b0b216a7546bca523ee8546f7a16d1f --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_draft/test_llama3.py @@ -0,0 +1,137 @@ +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch + +import torch +from transformers import LlamaConfig + +from specforge.modeling.draft.llama3_eagle import ( + LlamaAttention, + LlamaForCausalLMEagle3, + LlamaMLP, + LlamaRMSNorm, +) + +# from model_module import LlamaForCausalLMEagle3 + + +class TestLlamaForCausalLMEagle3Loading(unittest.TestCase): + + def setUp(self): + """Set up the test environment before each test.""" + self.temp_dir = tempfile.mkdtemp() + + config_dict = { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_key_value_heads": 8, + "num_hidden_layers": 1, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.28.1", + "use_cache": True, + "vocab_size": 128256, + "draft_vocab_size": 32000, + } + + self.config = LlamaConfig(**config_dict) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_model_initialization(self): + model = LlamaForCausalLMEagle3(self.config) + + self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) + self.assertIsInstance(model.midlayer.mlp, LlamaMLP) + self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) + self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) + self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) + self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size) + + def test_save_pretrained(self): + """Test the model's save_pretrained functionality.""" + model = LlamaForCausalLMEagle3(self.config) + + self.config.save_pretrained(self.temp_dir) + + model_path = os.path.join(self.temp_dir, "pytorch_model.bin") + torch.save(model.state_dict(), model_path) + + self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.json"))) + self.assertTrue(os.path.exists(model_path)) + + @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained") + def test_from_pretrained_mock(self, mock_from_pretrained): + """mock""" + mock_model = LlamaForCausalLMEagle3(self.config) + mock_from_pretrained.return_value = mock_model + + loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir) + mock_from_pretrained.assert_called_once_with(self.temp_dir) + self.assertIsInstance(loaded_model, LlamaForCausalLMEagle3) + + def test_model_forward_pass(self): + """forward""" + model = LlamaForCausalLMEagle3(self.config) + model.eval() + + batch_size = 2 + seq_len = 10 + + input_emb = torch.randn(batch_size, seq_len, self.config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size * 3) + attention_mask = torch.ones(batch_size, seq_len) + + with torch.no_grad(): + outputs = model( + inputs_embeds=input_emb, + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size)) + + def test_state_dict_compatibility(self): + model1 = LlamaForCausalLMEagle3(self.config) + model2 = LlamaForCausalLMEagle3(self.config) + + state_dict = model1.state_dict() + + model2.load_state_dict(state_dict) + + for name, param1 in model1.named_parameters(): + param2 = dict(model2.named_parameters())[name] + self.assertTrue(torch.equal(param1, param2)) + + def test_config_validation(self): + invalid_config = LlamaConfig( + vocab_size=1000, + hidden_size=127, + num_attention_heads=4, + num_key_value_heads=2, + ) + + with self.assertRaises(AttributeError): + LlamaForCausalLMEagle3(invalid_config) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + + suite.addTest(unittest.makeSuite(TestLlamaForCausalLMEagle3Loading)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/__init__.py b/progress/github/SpecForge/tests/test_modeling/test_target/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/__init__.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_gpt_oss.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_gpt_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..c62521b5dc06791a06af82be84673f6d3dc1e6bd --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_gpt_oss.py @@ -0,0 +1,107 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers import GptOssConfig, GptOssForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.gpt_oss import ( + GptOssForCausalLM as DistGptOssForCausalLM, +) +from tests.utils import get_available_port + + +def test_gpt_oss_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = GptOssConfig( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + intermediate_size_mlp=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + num_local_experts=4, + tie_word_embeddings=tie_word_embeddings, + initializer_range=0.02, + hidden_act="silu", + layer_types=[ + "sliding_attention", + "full_attention", + ], + ) + + # create the single-gpu + model = GptOssForCausalLM(config).cuda().eval() + + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # # load the model weights to the distributed model + print(f"Loading model from {temp_dir}") + dist_model = DistGptOssForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestGptOssTP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_gpt_oss_tp(self): + port = get_available_port() + mp.spawn(test_gpt_oss_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestGptOssTP)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama4_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama4_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..1cea91f3b91d5de42f312b11351d1c05815b4632 --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama4_tp.py @@ -0,0 +1,104 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers import Llama4ForCausalLM as HFLlama4ForCausalLM +from transformers import Llama4TextConfig + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.llama4 import ( + Llama4ForCausalLM as SFLlama4ForCausalLM, +) +from tests.utils import get_available_port + + +def test_llama4_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=world_size) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = Llama4TextConfig( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + intermediate_size_mlp=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=10, + num_key_value_heads=2, + head_dim=64, + num_local_experts=4, + tie_word_embeddings=tie_word_embeddings, + initializer_range=0.02, + hidden_act="silu", + ) + + # create the single-gpu + model = HFLlama4ForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + dist_model = SFLlama4ForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestLlama4TP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_llama4_tp(self): + port = get_available_port() + mp.spawn(test_llama4_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestLlama4TP)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..857c002dee178f69005e311e8d168b98f9912db2 --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_llama_tp.py @@ -0,0 +1,100 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers import LlamaConfig +from transformers import LlamaForCausalLM as HFLLamaForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.llama import ( + LlamaForCausalLM as SFLlamaForCausalLM, +) +from tests.utils import get_available_port + + +def test_llama3_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = LlamaConfig( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=10, + num_key_value_heads=2, + initializer_range=0.02, + hidden_act="silu", + rms_norm_eps=1e-6, + tie_word_embeddings=tie_word_embeddings, + ) + + # create the single-gpu + model = HFLLamaForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + dist_model = SFLlamaForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestLlama3TP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_llama3_tp(self): + port = get_available_port() + mp.spawn(test_llama3_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestLlama3TP)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_phi3_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_phi3_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9791d9d07ce4e369dd8570cb30a456dea038e0 --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_phi3_tp.py @@ -0,0 +1,106 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers.models.phi3 import Phi3Config +from transformers.models.phi3 import Phi3ForCausalLM as HFPhi3ForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.phi3 import ( + Phi3ForCausalLM as SFLPhi3ForCausalLM, +) +from tests.utils import get_available_port + + +def test_phi3_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = Phi3Config( + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=2, + max_position_embeddings=4096, + num_attention_heads=32, + num_key_value_heads=32, + tie_word_embeddings=tie_word_embeddings, + hidden_act="silu", + rms_norm_eps=1e-6, + attention_dropout=0.0, + resid_pdrop=0.0, + ) + + # create a simple single-gpu model + model = HFPhi3ForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + dist_model = SFLPhi3ForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + # Run inference on both models + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-4, + atol=1e-4, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestPhi3TP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_phi3_tp(self): + port = get_available_port() + mp.spawn(test_phi3_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + + suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPhi3TP)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen2_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen2_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..4581c3044f1f5d61c97f0789289c02ee424ad4fb --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen2_tp.py @@ -0,0 +1,105 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers import Qwen2Config +from transformers import Qwen2ForCausalLM as HFWen2ForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.qwen2 import ( + Qwen2ForCausalLM as SFLQwen2ForCausalLM, +) +from tests.utils import get_available_port + + +def test_qwen2_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = Qwen2Config( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + intermediate_size_mlp=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=10, + num_key_value_heads=2, + head_dim=64, + num_local_experts=4, + tie_word_embeddings=tie_word_embeddings, + initializer_range=0.02, + hidden_act="silu", + ) + + # create the single-gpu + model = HFWen2ForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + print(f"Loading model from {temp_dir}") + dist_model = SFLQwen2ForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestQwen2TP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_qwen2_tp(self): + port = get_available_port() + mp.spawn(test_qwen2_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestQwen2TP)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_moe_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_moe_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5f6450b269c9d39d396bff1cb8e935e52b4d1b --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_moe_tp.py @@ -0,0 +1,112 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers.models.qwen3_moe import Qwen3MoeConfig +from transformers.models.qwen3_moe import Qwen3MoeForCausalLM as HFWen3MoeForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.qwen3_moe import ( + Qwen3MoeForCausalLM as SFLQwen3MoeForCausalLM, +) +from tests.utils import get_available_port + + +def test_qwen3_moe_tp(rank, world_size, temp_dir, port, num_heads, num_kv_heads): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = Qwen3MoeConfig( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + moe_intermediate_size=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=num_heads, + num_key_value_heads=num_kv_heads, + num_experts=64, + num_experts_per_tok=8, + hidden_act="silu", + rms_norm_eps=1e-6, + tie_word_embeddings=tie_word_embeddings, + ) + + # create a simple single-gpu model + model = HFWen3MoeForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + print(f"Loading model from {temp_dir}") + dist_model = SFLQwen3MoeForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestQwen3MoeTP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_qwen3_moe_tp_no_kv_head_replicas(self): + # Set to 2 as only 2 GPU avaialble in CI + port = get_available_port() + mp.spawn(test_qwen3_moe_tp, nprocs=2, args=(2, self.temp_dir.name, port, 8, 4)) + + def test_qwen3_moe_tp_kv_head_replicas(self): + port = get_available_port() + mp.spawn(test_qwen3_moe_tp, nprocs=2, args=(2, self.temp_dir.name, port, 8, 1)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + + suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestQwen3MoeTP)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_tp.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_tp.py new file mode 100644 index 0000000000000000000000000000000000000000..eb447ec4567eca8d39b760e2542f70596c64b20b --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_custom_backend/test_qwen3_tp.py @@ -0,0 +1,106 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers.models.qwen3 import Qwen3Config +from transformers.models.qwen3 import Qwen3ForCausalLM as HFQwen3ForCausalLM + +from specforge.distributed import init_distributed +from specforge.modeling.target.custom_backend.qwen3 import ( + Qwen3ForCausalLM as SFLQwen3ForCausalLM, +) +from tests.utils import get_available_port + + +def test_qwen3_tp(rank, world_size, temp_dir, port): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=2) + set_seed(42) + + for tie_word_embeddings in [True, False]: + config = Qwen3Config( + vocab_size=1000, + hidden_size=384, + intermediate_size=512, + moe_intermediate_size=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=8, + num_key_value_heads=4, + hidden_act="silu", + rms_norm_eps=1e-6, + tie_word_embeddings=tie_word_embeddings, + ) + + # create a simple single-gpu model + model = HFQwen3ForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + print(f"Loading model from {temp_dir}") + dist_model = SFLQwen3ForCausalLM.from_pretrained(temp_dir).cuda() + dist.barrier() + + if tie_word_embeddings: + assert torch.equal( + model.get_input_embeddings().weight, model.lm_head.weight + ) + assert torch.equal( + dist_model.get_input_embeddings().weight, dist_model.lm_head.weight + ) + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + dist_logits = dist_model( + input_ids=input_ids, attention_mask=attention_mask + ).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + dist.destroy_process_group() + + +class TestQwen3TP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_qwen3_tp(self): + # Set to 2 as only 2 GPU avaialble in CI + port = get_available_port() + mp.spawn(test_qwen3_tp, nprocs=2, args=(2, self.temp_dir.name, port)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + + suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestQwen3TP)) + + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_sglang_backend/__init__.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_sglang_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7396fc88dbc77f22e7a9d9036f659a8bb57d59 --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_sglang_backend/test_sglang_backend.py @@ -0,0 +1,430 @@ +import os +import unittest + +import torch +import torch.multiprocessing as mp +from accelerate.utils import set_seed + +from specforge.distributed import init_distributed +from specforge.modeling.target.eagle3_target_model import SGLangEagle3TargetModel +from tests.utils import get_available_port + + +@torch.no_grad() +def test_dense(rank, world_size, port, tp_size): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + input_ids = torch.randint(0, 1000, (2, 256)).cuda() + attention_mask = torch.ones_like(input_ids) + loss_mask = torch.ones_like(input_ids) + + # test dense model + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + "unsloth/Llama-3.2-1B", + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.4, + enable_torch_compile=True, + enable_nccl_nvls=True, + enable_symm_mem=True, + enable_dp_attention=False, + enable_dp_lm_head=False, + enable_piecewise_cuda_graph=True, + ep_size=1, + context_length=256, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask + ) + + +@torch.no_grad() +def test_moe(rank, world_size, port, tp_size): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + input_ids = torch.randint(0, 1000, (2, 256)).cuda() + attention_mask = torch.ones_like(input_ids) + loss_mask = torch.ones_like(input_ids) + + # test moe model + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.4, + enable_torch_compile=True, + enable_nccl_nvls=True, + enable_symm_mem=True, + enable_dp_attention=True, + enable_dp_lm_head=True, + enable_piecewise_cuda_graph=True, + ep_size=2, + context_length=256, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask + ) + + +def test_vlm(rank, world_size, port, tp_size): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + # model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + image_path = os.path.join(os.path.dirname(__file__), "images", "demo.jpeg") + + # Use Qwen2.5-VL processor to prepare inputs + from qwen_vl_utils import process_vision_info + from transformers import Qwen2_5_VLProcessor + + processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + + # Create test messages with images (batch_size=2) + # Sample 1: single image + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + # Sample 2: single image (can use same or different image) + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What do you see in this picture?"}, + ], + } + ] + + # Process each sample separately to get correct format + batch_input_ids = [] + batch_attention_mask = [] + batch_pixel_values = [] + batch_image_grid_thw = [] + + for messages in [messages_1, messages_2]: + # Apply chat template + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # Process vision info to get actual image data + image_inputs, video_inputs = process_vision_info(messages) + + # Process with processor + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + batch_input_ids.append(inputs["input_ids"]) + batch_attention_mask.append(inputs["attention_mask"]) + batch_pixel_values.append(inputs["pixel_values"]) + batch_image_grid_thw.append(inputs["image_grid_thw"]) + + # Debug: print shapes + if rank == 0: + print(f"[Debug] batch_input_ids shapes: {[x.shape for x in batch_input_ids]}") + print( + f"[Debug] batch_pixel_values shapes: {[x.shape for x in batch_pixel_values]}" + ) + print( + f"[Debug] batch_image_grid_thw shapes: {[x.shape for x in batch_image_grid_thw]}" + ) + print(f"[Debug] batch_image_grid_thw values: {batch_image_grid_thw}") + # Count image tokens in input_ids + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + for i, ids in enumerate(batch_input_ids): + num_img_tokens = (ids == image_token_id).sum().item() + print(f"[Debug] Sample {i}: {num_img_tokens} image tokens in input_ids") + + # Pad input_ids and attention_mask to same length + max_len = max(ids.shape[1] for ids in batch_input_ids) + padded_input_ids = [] + padded_attention_mask = [] + padded_loss_mask = [] + + for input_ids, attention_mask in zip(batch_input_ids, batch_attention_mask): + pad_len = max_len - input_ids.shape[1] + if pad_len > 0: + input_ids = torch.nn.functional.pad( + input_ids, (0, pad_len), value=processor.tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0 + ) + padded_input_ids.append(input_ids) + padded_attention_mask.append(attention_mask) + padded_loss_mask.append( + attention_mask.clone() + ) # loss_mask same as attention_mask + + # Stack into batches + input_ids = torch.cat(padded_input_ids, dim=0).cuda() + attention_mask = torch.cat(padded_attention_mask, dim=0).cuda() + loss_mask = torch.cat(padded_loss_mask, dim=0).cuda() + + # pixel_values and image_grid_thw remain as lists (one per sample) + pixel_values = torch.cat(batch_pixel_values, dim=0).cuda() + image_grid_thw = [thw.cuda() for thw in batch_image_grid_thw] + + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + model_path, + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.75, + enable_torch_compile=True, + enable_nccl_nvls=False, + enable_symm_mem=False, # Disable to avoid nccl_allocator compilation issues + enable_dp_attention=True, + enable_dp_lm_head=True, + enable_piecewise_cuda_graph=True, + context_length=4096, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_vlm=True, + ) + + if rank == 0: + # Verify output shapes + print(f"[Rank {rank}] hidden_states shape: {sgl_out.hidden_states.shape}") + print(f"[Rank {rank}] target shape: {sgl_out.target.shape}") + print(f"[Rank {rank}] input_ids shape: {sgl_out.input_ids.shape}") + + +def test_vlm_multi_batch(rank, world_size, port, tp_size): + """Test VLM with larger batch size (4 samples) and varying image counts.""" + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + model_path = "Qwen/Qwen2.5-VL-32B-Instruct" + + from qwen_vl_utils import process_vision_info + from transformers import Qwen2_5_VLProcessor + + processor = Qwen2_5_VLProcessor.from_pretrained(model_path) + + image_path = os.path.join(os.path.dirname(__file__), "images", "demo.jpeg") + + # Create test messages with different configurations (batch_size=4) + # Sample 1: single image + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "Describe this image in detail."}, + ], + } + ] + + # Sample 2: single image with different prompt + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What objects can you see in this picture?"}, + ], + } + ] + + # Sample 3: single image with longer prompt + messages_3 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + { + "type": "text", + "text": "Please analyze this image and describe the main subject, background, colors, and any notable details you observe.", + }, + ], + } + ] + + # Sample 4: single image with short prompt + messages_4 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": "What is this?"}, + ], + } + ] + + all_messages = [messages_1, messages_2, messages_3, messages_4] + batch_size = len(all_messages) + + # Process each sample separately to get correct format + batch_input_ids = [] + batch_attention_mask = [] + batch_pixel_values = [] + batch_image_grid_thw = [] + + for messages in all_messages: + # Apply chat template + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # Process vision info to get actual image data + image_inputs, video_inputs = process_vision_info(messages) + + # Process with processor + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + batch_input_ids.append(inputs["input_ids"]) + batch_attention_mask.append(inputs["attention_mask"]) + batch_pixel_values.append(inputs["pixel_values"]) + batch_image_grid_thw.append(inputs["image_grid_thw"]) + + # Pad input_ids and attention_mask to same length + max_len = max(ids.shape[1] for ids in batch_input_ids) + padded_input_ids = [] + padded_attention_mask = [] + padded_loss_mask = [] + + for input_ids, attention_mask in zip(batch_input_ids, batch_attention_mask): + pad_len = max_len - input_ids.shape[1] + if pad_len > 0: + input_ids = torch.nn.functional.pad( + input_ids, (0, pad_len), value=processor.tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, pad_len), value=0 + ) + padded_input_ids.append(input_ids) + padded_attention_mask.append(attention_mask) + padded_loss_mask.append( + attention_mask.clone() + ) # loss_mask same as attention_mask + + # Stack into batches + input_ids = torch.cat(padded_input_ids, dim=0).cuda() + attention_mask = torch.cat(padded_attention_mask, dim=0).cuda() + loss_mask = torch.cat(padded_loss_mask, dim=0).cuda() + + # pixel_values and image_grid_thw remain as lists (one per sample) + pixel_values = torch.cat(batch_pixel_values, dim=0).cuda() + image_grid_thw = [thw.cuda() for thw in batch_image_grid_thw] + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + model_path, + torch_dtype=torch.float16, + device="cuda", + attention_backend="fa3", + mem_fraction_static=0.4, + enable_torch_compile=True, + enable_nccl_nvls=False, + enable_symm_mem=False, + enable_dp_attention=True, + enable_dp_lm_head=True, + enable_piecewise_cuda_graph=True, + context_length=4096, + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_vlm=True, + ) + + if rank == 0: + # Verify output shapes + print(f"\n{'='*60}") + print(f"[test_vlm_multi_batch] Results:") + print(f"[Rank {rank}] hidden_states shape: {sgl_out.hidden_states.shape}") + print(f"[Rank {rank}] target shape: {sgl_out.target.shape}") + print(f"[Rank {rank}] input_ids shape: {sgl_out.input_ids.shape}") + + # Verify batch dimension matches + assert ( + sgl_out.input_ids.shape[0] == batch_size + ), f"Expected batch_size={batch_size}, got {sgl_out.input_ids.shape[0]}" + print(f"[Rank {rank}] Batch size verification: PASSED") + print(f"{'='*60}\n") + + +class TestTargetModelBackend(unittest.TestCase): + + def test_sglang_backend_with_dense(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_dense, nprocs=world_size, args=(world_size, port, 2)) + + def test_sglang_backend_with_moe(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_moe, nprocs=world_size, args=(world_size, port, 2)) + + def test_sglang_backend_with_vlm(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_vlm, nprocs=world_size, args=(world_size, port, 2)) + + def test_sglang_backend_with_vlm_multi_batch(self): + world_size = 2 + port = get_available_port() + mp.spawn(test_vlm_multi_batch, nprocs=world_size, args=(world_size, port, 2)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestTargetModelBackend)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_modeling/test_target/test_target_model_backend.py b/progress/github/SpecForge/tests/test_modeling/test_target/test_target_model_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..232a19813a1493b75dff294d7e7e2b2bb5c27aab --- /dev/null +++ b/progress/github/SpecForge/tests/test_modeling/test_target/test_target_model_backend.py @@ -0,0 +1,108 @@ +import os +import unittest + +import torch +import torch.multiprocessing as mp +from accelerate.utils import set_seed + +from specforge.distributed import init_distributed +from specforge.modeling.target.eagle3_target_model import ( + CustomEagle3TargetModel, + HFEagle3TargetModel, + SGLangEagle3TargetModel, +) +from tests.utils import get_available_port + + +@torch.no_grad() +def test_target_model_backend(rank, world_size, port, tp_size): + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + init_distributed(tp_size=tp_size) + set_seed(42) + + input_ids = torch.randint(0, 1000, (2, 256)).cuda() + attention_mask = torch.ones_like(input_ids) + loss_mask = torch.ones_like(input_ids) + + hf_target_model = HFEagle3TargetModel.from_pretrained( + "unsloth/Llama-3.2-1B", torch_dtype=torch.float16, device="cuda" + ) + hf_target_model.set_aux_hidden_states_layers() + hf_out = hf_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + del hf_target_model + + custom_target_model = CustomEagle3TargetModel.from_pretrained( + "unsloth/Llama-3.2-1B", torch_dtype=torch.float16, device="cuda" + ) + custom_target_model.set_aux_hidden_states_layers() + custom_out = custom_target_model.generate_eagle3_data( + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + del custom_target_model + + # compare weights + assert torch.allclose( + hf_out.target, custom_out.target, atol=1e-5, rtol=1e-5 + ), f"Logits are not close: \nhf: {hf_out[0] - custom_out[0]}" + assert torch.allclose( + hf_out.loss_mask, custom_out.loss_mask, atol=1e-5, rtol=1e-5 + ), f"Logits are not close: \ndiff: {hf_out[1] - custom_out[1]}" + assert torch.allclose( + hf_out.input_ids, custom_out.input_ids, atol=1e-5, rtol=1e-5 + ), f"Logits are not close: \ndiff: {hf_out[1] - custom_out[1]}" + assert torch.allclose( + hf_out.hidden_states, custom_out.hidden_states, atol=1e-5, rtol=1e-5 + ), f"Logits are not close: \ndiff: {hf_out[1] - custom_out[1]}" + + sgl_target_model = SGLangEagle3TargetModel.from_pretrained( + "unsloth/Llama-3.2-1B", torch_dtype=torch.float16, device="cuda" + ) + sgl_target_model.set_aux_hidden_states_layers() + sgl_out = sgl_target_model.generate_eagle3_data( + input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask + ) + del sgl_target_model + + assert torch.equal(hf_out.loss_mask, sgl_out.loss_mask) + assert torch.equal(hf_out.input_ids, sgl_out.input_ids) + assert torch.allclose( + hf_out.hidden_states, sgl_out.hidden_states, atol=1e-1, rtol=1e-2 + ), f"Hidden states are not close, diff: \n{(hf_out.hidden_states - sgl_out.hidden_states).abs().max()}" + assert torch.allclose( + hf_out.target, sgl_out.target.half(), atol=1e-1, rtol=1e-2 + ), f"Target are not close, diff: \n{(hf_out.target - sgl_out.target).abs().max()}" + + +class TestTargetModelBackend(unittest.TestCase): + + def test_target_model_backend_dp(self): + world_size = 2 + port = get_available_port() + mp.spawn( + test_target_model_backend, nprocs=world_size, args=(world_size, port, 1) + ) + + def test_target_model_backend_tp(self): + world_size = 2 + port = get_available_port() + mp.spawn( + test_target_model_backend, nprocs=world_size, args=(world_size, port, 2) + ) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestTargetModelBackend)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/progress/github/SpecForge/tests/test_scripts/__init__.py b/progress/github/SpecForge/tests/test_scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_scripts/test_prepare_data.py b/progress/github/SpecForge/tests/test_scripts/test_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9324a010edec2b43dc52ada77f432925684090 --- /dev/null +++ b/progress/github/SpecForge/tests/test_scripts/test_prepare_data.py @@ -0,0 +1,26 @@ +import unittest +from pathlib import Path + +from sglang.utils import execute_shell_command + +CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache") + + +class TestPrepareData(unittest.TestCase): + + def test_prepare_sharegpt(self): + sharegpt_train_path = CACHE_DIR.joinpath("dataset", "sharegpt_train.jsonl") + + if sharegpt_train_path.exists(): + # delete the file + sharegpt_train_path.unlink() + process = execute_shell_command( + "python scripts/prepare_data.py --dataset sharegpt" + ) + process.wait() + self.assertEqual(process.returncode, 0) + self.assertTrue(sharegpt_train_path.exists()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/progress/github/SpecForge/tests/test_scripts/test_regenerate_train_data.py b/progress/github/SpecForge/tests/test_scripts/test_regenerate_train_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9563ebc37d4cea8d800a90927e90b330f56245e4 --- /dev/null +++ b/progress/github/SpecForge/tests/test_scripts/test_regenerate_train_data.py @@ -0,0 +1,57 @@ +import unittest +from pathlib import Path + +from tests.utils import execute_shell_command, wait_for_server + +CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache") + + +class TestRegenerateTrainData(unittest.TestCase): + + def test_regenerate_sharegpt(self): + # prepare data + data_process = execute_shell_command( + "python scripts/prepare_data.py --dataset sharegpt" + ) + data_process.wait() + + # launch sglang + sglang_process = execute_shell_command( + """python3 -m sglang.launch_server \ + --model unsloth/Llama-3.2-1B-Instruct \ + --tp 1 \ + --cuda-graph-bs 4 \ + --dtype bfloat16 \ + --mem-frac=0.8 \ + --port 30000 + """, + disable_proxy=True, + enable_hf_mirror=True, + ) + wait_for_server(f"http://localhost:30000", disable_proxy=True) + + regeneration_process = execute_shell_command( + """python scripts/regenerate_train_data.py \ + --model unsloth/Llama-3.2-1B-Instruct \ + --concurrency 128 \ + --max-tokens 128 \ + --server-address localhost:30000 \ + --temperature 0.8 \ + --input-file-path ./cache/dataset/sharegpt_train.jsonl \ + --output-file-path ./cache/dataset/sharegpt_train_regen.jsonl \ + --num-samples 10 + """, + disable_proxy=True, + enable_hf_mirror=True, + ) + regeneration_process.wait() + self.assertEqual(regeneration_process.returncode, 0) + self.assertTrue( + CACHE_DIR.joinpath("dataset", "sharegpt_train_regen.jsonl").exists() + ) + sglang_process.terminate() + sglang_process.wait() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/progress/github/SpecForge/tests/test_scripts/test_train_eagle3.py b/progress/github/SpecForge/tests/test_scripts/test_train_eagle3.py new file mode 100644 index 0000000000000000000000000000000000000000..a543e1e0080f45e038e56ec5a5fc9483a942858e --- /dev/null +++ b/progress/github/SpecForge/tests/test_scripts/test_train_eagle3.py @@ -0,0 +1,137 @@ +import shutil +import unittest +from pathlib import Path + +from tests.utils import execute_shell_command + +CACHE_DIR = Path(__file__).parent.parent.parent.joinpath("cache") + + +def replace_in_script(script_path: Path, pattern: str, replacement: str): + with open(script_path, "r") as f: + script = f.readlines() + script = [line.replace(pattern, replacement) for line in script] + with open(script_path, "w") as f: + for line in script: + f.write(line) + + +class TestTrainEagle3(unittest.TestCase): + + def setUp(self) -> None: + # prepare data + data_process = execute_shell_command( + "python scripts/prepare_data.py --dataset sharegpt" + ) + data_process.wait() + + # modify the sccript to only train for 10 steps + # add --max-num-steps 10 to the launch command + script_path = Path(__file__).parent.parent.parent.joinpath( + "examples", "run_llama3.1_8b_eagle3_online.sh" + ) + with open(script_path, "r") as f: + script = f.readlines() + + # remove empty lines + script = [line for line in script if line.strip()] + script[-1] = script[-1].rstrip() + " --max-num-steps 10" + + # replace meta-llama/Llama-3.1-8B-Instruct with unsloth/Llama-3.2-1B-Instruct + # so that we don't need HF token for gated repo + script = [ + line.replace( + "meta-llama/Llama-3.1-8B-Instruct", "nreHieW/Llama-3.1-8B-Instruct" + ) + for line in script + ] + + # write the script back to the file + with open(script_path, "w") as f: + for line in script: + f.write(line) + + def test_online_train_eagle3_with_sglang_backend(self): + # run training + train_process = execute_shell_command( + "bash examples/run_llama3.1_8b_eagle3_online.sh 2" + ) + train_process.wait() + self.assertEqual(train_process.returncode, 0) + + def test_online_train_eagle3_with_hf_backend(self): + # replace --target-model-backend sglang with --target-model-backend hf + script_path = Path(__file__).parent.parent.parent.joinpath( + "examples", "run_llama3.1_8b_eagle3_online.sh" + ) + replace_in_script( + script_path, "--target-model-backend sglang", "--target-model-backend hf" + ) + + # run training + train_process = execute_shell_command( + "bash examples/run_llama3.1_8b_eagle3_online.sh 2" + ) + train_process.wait() + self.assertEqual(train_process.returncode, 0) + + def test_online_train_eagle3_with_custom_backend(self): + # replace --target-model-backend sglang with --target-model-backend custom + script_path = Path(__file__).parent.parent.parent.joinpath( + "examples", "run_llama3.1_8b_eagle3_online.sh" + ) + replace_in_script( + script_path, + "--target-model-backend sglang", + "--target-model-backend custom", + ) + + # run training + train_process = execute_shell_command( + "bash examples/run_llama3.1_8b_eagle3_online.sh 2" + ) + train_process.wait() + self.assertEqual(train_process.returncode, 0) + + def test_offline_train_eagle3(self): + # remove the hidden states if they exist + script_path = Path(__file__).parent.parent.parent.joinpath( + "examples", "run_llama3.1_8b_eagle3_offline.sh" + ) + replace_in_script( + script_path, + "meta-llama/Llama-3.1-8B-Instruct", + "nreHieW/Llama-3.1-8B-Instruct", + ) + replace_in_script( + script_path, + "--batch-size 32", + "--batch-size 5", + ) + replace_in_script( + script_path, + "scripts/prepare_hidden_states.py", + "scripts/prepare_hidden_states.py --num-samples 10", + ) + replace_in_script( + script_path, + "$ROOT_DIR/scripts/train_eagle3.py", + "$ROOT_DIR/scripts/train_eagle3.py --max-num-steps 2", + ) + + hidden_states_path = Path(__file__).parent.parent.parent.joinpath( + "cache", "hidden_states", "sharegpt_train_Llama-3.1-8B-Instruct" + ) + if hidden_states_path.exists(): + # delete the directory + shutil.rmtree(hidden_states_path) + + training_process = execute_shell_command( + "bash examples/run_llama3.1_8b_eagle3_offline.sh 2", + ) + training_process.wait() + self.assertEqual(training_process.returncode, 0) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/progress/github/SpecForge/tests/test_utils/__init__.py b/progress/github/SpecForge/tests/test_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_utils/test_flex_attention.py b/progress/github/SpecForge/tests/test_utils/test_flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bc7ce3b809d1f93abe5788519982383b71dd48 --- /dev/null +++ b/progress/github/SpecForge/tests/test_utils/test_flex_attention.py @@ -0,0 +1,292 @@ +import unittest + +import torch +import torch._dynamo as dynamo +from transformers import LlamaConfig +from transformers.cache_utils import DynamicCache + +from specforge.modeling.draft.flex_attention import ( + compile_friendly_create_block_mask, + compile_friendly_flex_attention, + generate_eagle3_mask, +) +from specforge.modeling.draft.llama3_eagle import ( + LlamaAttention, + LlamaFlexAttention, + prepare_decoder_attention_mask, +) +from specforge.utils import padding + +from .utils import norm_tensor + +dynamo.config.recompile_limit = 64 +TTT_LENGTH = 7 +torch.manual_seed(0) + + +class TestFlexAttention(unittest.TestCase): + + def setUp(self): + torch.manual_seed(0) + self.config_dict = { + "hidden_size": 128, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "vocab_size": 32000, + "intermediate_size": 688, + "hidden_act": "silu", + "num_hidden_layers": 1, + "torch_dtype": "float32", + } + self.config = LlamaConfig(**self.config_dict) + + self.seq_lengths = [128, 200, 256, 300, 512, 800, 1024, 2048] + self.dtype = torch.float32 + + def test_forward_pass_comparison(self): + """Test forward pass comparison between LlamaAttention and LlamaFlexAttention.""" + for seq_len in self.seq_lengths: + with self.subTest(seq_len=seq_len): + self._test_forward_pass_comparison_for_seq_len(seq_len) + + def _test_forward_pass_comparison_for_seq_len(self, seq_len): + """Helper method to test forward pass comparison for a specific sequence length.""" + attention = LlamaAttention(self.config).to("cuda").to(self.dtype) + flex_attention = LlamaFlexAttention(self.config).to("cuda").to(self.dtype) + + # Ensure same weights + with torch.no_grad(): + flex_attention.q_proj.weight.copy_(attention.q_proj.weight) + flex_attention.k_proj.weight.copy_(attention.k_proj.weight) + flex_attention.v_proj.weight.copy_(attention.v_proj.weight) + flex_attention.o_proj.weight.copy_(attention.o_proj.weight) + + attention.eval() + flex_attention.eval() + batch_size = 2 + hidden_size = self.config.hidden_size * 2 + + ############### Attention Inputs ############## + + position_ids = ( + torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to("cuda") + ) + cache_hidden = [[], []] # [cache_k, cache_v] + attention_mask = torch.ones(batch_size, seq_len, dtype=self.dtype).to("cuda") + # Simulate one item in the batch is masked and not taking a full block. + padding_start_index = seq_len - min( + 200, seq_len // 3 + ) # Adjust padding based on seq_len + attention_mask[1, padding_start_index:] = False + input_embeds = norm_tensor( + (batch_size, seq_len, self.config.hidden_size), + device="cuda", + dtype=self.dtype, + ) + decoder_attention_mask = prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_len), + inputs_embeds=input_embeds, + past_key_values_length=0, + ) + hidden_states_list = [] + flex_hidden_states_list = [] + for idx in range(TTT_LENGTH): + hidden_states = norm_tensor( + (batch_size, seq_len, hidden_size), device="cuda", dtype=self.dtype + ) + flex_hidden_states = hidden_states.clone().detach() + hidden_states_list.append(hidden_states) + flex_hidden_states_list.append(flex_hidden_states) + + ############### Flex Attention Inputs ############## + flex_position_ids = position_ids.clone() + past_key_values = DynamicCache() + for idx in range(TTT_LENGTH): + is_last = idx == TTT_LENGTH - 1 + with torch.no_grad(): + output = attention( + hidden_states=hidden_states_list[idx], + attention_mask=decoder_attention_mask, + position_ids=position_ids, + cache_hidden=cache_hidden, + output_attentions=False, + use_cache=True, + ) + with torch.no_grad(): + output_flex = flex_attention( + hidden_states=flex_hidden_states_list[idx], + attention_mask=attention_mask, + position_ids=flex_position_ids, + past_key_values=past_key_values, + ) + torch.testing.assert_close( + output[0][: -1 - idx], output_flex[0][: -1 - idx], atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + output[1][: padding_start_index - idx], + output_flex[1][: padding_start_index - idx], + atol=1e-2, + rtol=1e-2, + ) + + # Check output shape + expected_output_shape = (batch_size, seq_len, self.config.hidden_size) + self.assertEqual(output_flex.shape, expected_output_shape) + # Check output is not NaN or Inf + self.assertFalse(torch.isnan(output_flex).any()) + self.assertFalse(torch.isinf(output_flex).any()) + + def test_backward_pass_gradient_comparison(self): + """Test backward pass comparing gradients between LlamaAttention and LlamaFlexAttention.""" + for seq_len in self.seq_lengths: + with self.subTest(seq_len=seq_len): + self._test_backward_pass_gradient_comparison_for_seq_len(seq_len) + + def _test_backward_pass_gradient_comparison_for_seq_len(self, seq_len): + """Helper method to test backward pass gradient comparison for a specific sequence length.""" + attention = LlamaAttention(self.config).to("cuda").to(self.dtype) + flex_attention = LlamaFlexAttention(self.config).to("cuda").to(self.dtype) + + # Ensure same weights + with torch.no_grad(): + flex_attention.q_proj.weight.copy_(attention.q_proj.weight) + flex_attention.k_proj.weight.copy_(attention.k_proj.weight) + flex_attention.v_proj.weight.copy_(attention.v_proj.weight) + flex_attention.o_proj.weight.copy_(attention.o_proj.weight) + + batch_size = 2 + hidden_size = self.config.hidden_size * 2 + + ############### Attention Inputs ############## + position_ids = ( + torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to("cuda") + ) + cache_hidden = [[], []] # [cache_k, cache_v] + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool).to("cuda") + # Simulate one item in the batch is masked and not taking a full block. + # padding_start_index = seq_len - 50 + # attention_mask[1, padding_start_index:] = False + input_embeds = norm_tensor( + (batch_size, seq_len, self.config.hidden_size), + device="cuda", + dtype=self.dtype, + ) + decoder_attention_mask = prepare_decoder_attention_mask( + attention_mask=attention_mask, + input_shape=(batch_size, seq_len), + inputs_embeds=input_embeds, + past_key_values_length=0, + ) + + ############### Flex Attention Inputs ############## + flex_position_ids = position_ids.clone() + ttt_length = TTT_LENGTH + past_key_values = DynamicCache() + loss_mask = torch.ones( + batch_size, seq_len, dtype=self.dtype, requires_grad=False + ).to("cuda") + + # Create input tensors that require gradients + loss_list = [] + loss_flex_list = [] + hidden_states_list = [] + flex_hidden_states_list = [] + for idx in range(TTT_LENGTH): + hidden_states = norm_tensor( + (batch_size, seq_len, hidden_size), device="cuda", dtype=self.dtype + ) + flex_hidden_states = hidden_states.clone().detach() + hidden_states_list.append(hidden_states) + flex_hidden_states_list.append(flex_hidden_states) + + for idx in range(TTT_LENGTH): + is_last = idx == TTT_LENGTH - 1 + output = attention( + hidden_states=hidden_states_list[idx], + attention_mask=decoder_attention_mask, + position_ids=position_ids, + cache_hidden=cache_hidden, + output_attentions=False, + use_cache=True, + ) + output_flex = flex_attention( + hidden_states=flex_hidden_states_list[idx], + attention_mask=attention_mask, + position_ids=flex_position_ids, + past_key_values=past_key_values, + ) + # Apply loss mask on calculation over batch + loss = (output * loss_mask[..., None]).sum().mean() + loss_flex = (output_flex * loss_mask[..., None]).sum().mean() + torch.testing.assert_close(loss, loss_flex, atol=1e-2, rtol=1e-2) + loss_list.append(loss) + loss_flex_list.append(loss_flex) + # Compare gradients + + if not is_last: + # Step 5.7: we need to update the loss mask + loss_mask = padding(loss_mask, left=False) + mean_loss = sum(loss_list) / len(loss_list) + mean_loss_flex = sum(loss_flex_list) / len(loss_flex_list) + mean_loss.backward() + mean_loss_flex.backward() + projections = ["q_proj", "k_proj", "v_proj", "o_proj"] + for proj_name in projections: + torch.testing.assert_close( + getattr(attention, proj_name).weight.grad, + getattr(flex_attention, proj_name).weight.grad, + atol=1e-2, + rtol=1e-2, + ) + + +class TestEagle3FlexMask(unittest.TestCase): + + def test_eagle3_flex_mask(self): + B = 1 + H = 1 + S = 128 * 8 + D = 128 + Q_LEN = S + KV_LEN = S * 3 + lck = 128 * 2 + data_type = torch.bfloat16 + query = norm_tensor((B, H, S, D), device="cuda", dtype=data_type) + key_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type) + value_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type) + seq_lengths = torch.tensor([S], device="cuda", dtype=torch.int32) + seq_lengths -= lck + block_mask = compile_friendly_create_block_mask( + mask_mod=generate_eagle3_mask( + seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck + ), + B=1, + H=1, + Q_LEN=Q_LEN, + KV_LEN=KV_LEN, + device=query.device, + ) + # fmt: off + expected_mask = torch.tensor([[[ + [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]]], dtype=torch.int32).to(query.device) + # fmt: on + dense_mask = block_mask.to_dense() + assert torch.allclose(dense_mask, expected_mask) + output = compile_friendly_flex_attention( + query, key_cache, value_cache, block_mask=block_mask + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/progress/github/SpecForge/tests/test_utils/test_loss.py b/progress/github/SpecForge/tests/test_utils/test_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..56dc47ca513fa205f335bbe74739a25f2dac2c14 --- /dev/null +++ b/progress/github/SpecForge/tests/test_utils/test_loss.py @@ -0,0 +1,87 @@ +import unittest + +import torch + +from specforge.core.loss import LogSoftmaxLoss, _compute_loss + +from .utils import norm_tensor + + +class TestLogSoftmaxLoss(unittest.TestCase): + + TTT_LENGTH = 7 + + def _test_loss_and_gradient_calculation(self, B, T, V): + if not torch.cuda.is_available(): + device = "cpu" + else: + device = "cuda" + + logits = norm_tensor((B, T, V), device, torch.float32) + logits2 = logits.clone().detach().requires_grad_(True) + target = norm_tensor((B, T, V), device, torch.float32) + position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device) + + output1 = LogSoftmaxLoss.apply(logits, target, position_mask) + output2 = _compute_loss(logits2, target, position_mask) + torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4) + + output1.backward() + output2.backward() + torch.testing.assert_close(logits.grad, logits2.grad, rtol=1e-4, atol=1e-4) + + def test_loss(self): + B = [1, 2, 4] + T = [1024, 2048, 4096, 6000] + V = [4096, 8192, 10000] + for b in B: + for t in T: + for v in V: + self._test_loss_and_gradient_calculation(b, t, v) + + def test_ttt_loss_accumulation(self): + if not torch.cuda.is_available(): + device = "cpu" + else: + device = "cuda" + + B, T, V = 1, 1024, 3200 + plosses = [] + plosses_compare = [] + logits_list = [ + norm_tensor((B, T, V), device, torch.float32) + for _ in range(self.TTT_LENGTH) + ] + logits_list_copy = [ + logits.clone().detach().requires_grad_(True) for logits in logits_list + ] + for i in range(self.TTT_LENGTH): + logits = logits_list[i] + logits2 = logits_list_copy[i] + target = norm_tensor((B, T, V), device, torch.float32) + position_mask = torch.randint( + 0, 2, (B, T, 1), dtype=torch.bool, device=device + ) + + output1 = LogSoftmaxLoss.apply(logits, target, position_mask) + output2 = _compute_loss(logits2, target, position_mask) + torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4) + plosses.append(output1) + plosses_compare.append(output2) + + ploss_weight = [0.8**i for i in range(len(plosses))] + ploss = ( + sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) + / self.TTT_LENGTH + ) + ploss_compare = ( + sum([ploss_weight[i] * plosses_compare[i] for i in range(len(plosses))]) + / self.TTT_LENGTH + ) + torch.testing.assert_close(ploss, ploss_compare, rtol=1e-4, atol=1e-4) + ploss.backward() + ploss_compare.backward() + for i in range(self.TTT_LENGTH): + torch.testing.assert_close( + logits_list[i].grad, logits_list_copy[i].grad, rtol=1e-4, atol=1e-4 + ) diff --git a/progress/github/SpecForge/tests/test_utils/utils.py b/progress/github/SpecForge/tests/test_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5688218f5ac82961ede231262ab642f30ceaf5a5 --- /dev/null +++ b/progress/github/SpecForge/tests/test_utils/utils.py @@ -0,0 +1,8 @@ +import torch +import torch.nn.init + + +def norm_tensor(shape, device, dtype, std=0.02): + t = torch.empty(shape, device=device, dtype=dtype, requires_grad=True) + torch.nn.init.trunc_normal_(t, mean=0.0, std=std) + return t diff --git a/progress/github/SpecForge/tests/test_vlm/__init__.py b/progress/github/SpecForge/tests/test_vlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/progress/github/SpecForge/tests/test_vlm/test_qwenvl_loss_mask.py b/progress/github/SpecForge/tests/test_vlm/test_qwenvl_loss_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..74523861fde945e15ae3a098fec0bf65f90eb18d --- /dev/null +++ b/progress/github/SpecForge/tests/test_vlm/test_qwenvl_loss_mask.py @@ -0,0 +1,64 @@ +from transformers import AutoProcessor + +from specforge.data.preprocessing import preprocess_vlm_conversations +from specforge.data.template import TEMPLATE_REGISTRY + +model_path = "Qwen/Qwen2.5-VL-7B-Instruct" +processor = AutoProcessor.from_pretrained(model_path) +# ANSI color codes +RED = "\033[91m" +GREEN = "\033[92m" +RESET = "\033[0m" + + +def test_preprocess_vlm_conversations(): + conversations = [ + {"role": "user", "content": "what is in the image?"}, + {"role": "assistant", "content": "This is an image of a cat."}, + ] + + examples = { + "id": ["example1"], + "image": [ + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + ], + "conversations": [conversations], + } + # examples = [examples] # Wrap in a list to match expected input format + chat_template = TEMPLATE_REGISTRY.get("qwen2-vl") + processed = preprocess_vlm_conversations(processor, examples, chat_template, 4096) + + input_ids = processed["input_ids"][0] + print(f"Loss mask sum: {processed['loss_mask'][0].sum()}") + loss_mask = processed["loss_mask"][0].squeeze(0).tolist() + input_ids = input_ids.squeeze(0) + current_mask = input_ids[0] + current_ids = [] + + for i in range(len(input_ids)): + # for i in range(input_ids.shape[-1]): + if current_mask == loss_mask[i]: + current_ids.append(input_ids[i]) + else: + decoded_text = processor.tokenizer.decode( + current_ids, skip_special_tokens=False + ) + if current_mask == 0: + print(f"{RED}{decoded_text}{RESET}", end="") + else: + print(f"{GREEN}{decoded_text}{RESET}", end="") + current_ids = [input_ids[i]] + current_mask = loss_mask[i] + + print( + f"{GREEN}{processor.tokenizer.decode(current_ids, skip_special_tokens=False)}{RESET}" + ) + + print() + print(f"input_ids shape: {processed['input_ids'][0].shape}") + print(f"loss_mask shape: {processed['loss_mask'][0].shape}") + # print(f"hidden_state shape: {processed['hidden_state'].shape}") + + +if __name__ == "__main__": + test_preprocess_vlm_conversations()